Fix optimizer reloading from checkpoint #1329

Merged 2 commits into main from fix_opt_state on Aug 15, 2024
Conversation

@awni (Member) commented Aug 14, 2024

This is kind of a weird problem.

Basically if you have a model like:

import mlx.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2, 2)
        self.drop = nn.Dropout(p=0.5)
        self.l2 = nn.Linear(2, 2)

It has a parameters() tree like {"l1": ..., "drop": {}, "l2": ...}, and the optimizer state is set up to match it. When the optimizer state is serialized, it gets flattened, which removes the empty "drop" value from the tree. That in turn causes issues when deserializing the optimizer, because its state no longer matches the structure of parameters().
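
The round trip can be seen with the tree utilities in mlx.utils (the integer values here are just placeholders for parameter arrays):

from mlx.utils import tree_flatten, tree_unflatten

state = {"l1": 1, "drop": {}, "l2": 2}
flat = tree_flatten(state)       # [("l1", 1), ("l2", 2)] -- "drop" is dropped
restored = tree_unflatten(flat)  # {"l1": 1, "l2": 2}
assert "drop" not in restored    # structure no longer matches parameters()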

I think there are two possible fixes here. I implemented the first, but I'm open to trying the second:

  • Delay completing the initialization of the optimizer until we have the first parameter state, so we can match it.
  • Change Module.parameters() and Module.trainable_parameters() to not include empty values (a sketch of this option follows the list).
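
For the second option, the empty-value filtering might look something like this hypothetical prune_empty helper (an illustrative sketch, not the code in this PR):

def prune_empty(tree):
    # Recursively drop empty dict/list sub-trees, so that e.g. the
    # Dropout entry ("drop": {}) would not appear in parameters().
    if isinstance(tree, dict):
        out = {k: prune_empty(v) for k, v in tree.items()}
        return {k: v for k, v in out.items()
                if not (isinstance(v, (dict, list)) and len(v) == 0)}
    if isinstance(tree, list):
        out = [prune_empty(v) for v in tree]
        return [v for v in out
                if not (isinstance(v, (dict, list)) and len(v) == 0)]
    return tree

prune_empty({"l1": 1, "drop": {}, "l2": 2})  # {"l1": 1, "l2": 2}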

This closes #1328

@angeloskath (Member) left a comment

Other than the one comment, I think this looks good.

I am not too happy that the tree-walking logic starts crawling into everything, but so far it is not too annoying.

Review thread on python/mlx/optimizers/optimizers.py (resolved)
@awni merged commit ae5b5ca into main on Aug 15, 2024 (3 checks passed).
@awni deleted the fix_opt_state branch on August 15, 2024 at 14:33.
Successfully merging this pull request may close these issues:

[BUG] Unable to load from a saved checkpoint, KeyError for all dropout modules...