
Add documentation and helper functions to store and retrieve checkpoints #127

Closed

IgorSusmelj opened this issue Feb 13, 2021 · 6 comments

@IgorSusmelj (Contributor)

We have a model zoo but no documentation on how to load the models from it. We also have the PyTorch Lightning wrapper for training the models but no information on how to use the models afterward.

It would be great to provide the following (documentation and, if it makes sense, helper functions):

How to load a model that has been saved by PyTorch Lightning (e.g. I just want to get the ResNet backbone from it)?

# run lightly from the CLI or using the PyTorch Lightning wrapper
...
# now you should have a folder with a checkpoint
# let's load the checkpoint in another script (e.g. to do transfer learning)
import torch
import lightly

ckpt = torch.load('my_lightning_checkpoint.pth')
my_resnet = lightly.models.ResNetGenerator()

# load checkpoint state dict to my_resnet
... # TODO: how to load ckpt state dict into my_resnet
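
For illustration, one possible approach (a sketch, not an official lightly helper): a Lightning checkpoint is a dictionary whose 'state_dict' entry holds the weights, with each key prefixed by the attribute name inside the LightningModule. Assuming the wrapper stores the ResNet under self.backbone (the exact prefix depends on the wrapper):

# sketch: keep only the 'backbone.' entries and strip the prefix
state_dict = ckpt['state_dict']
backbone_state = {
    k[len('backbone.'):]: v
    for k, v in state_dict.items()
    if k.startswith('backbone.')
}
my_resnet.load_state_dict(backbone_state)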

How to manually save and load a model?

# simple example of storing and reading a state dict
import torch
import lightly

backbone = lightly.models.ResNetGenerator()
simclr_model = lightly.models.SimCLR(backbone)

# save the weights of the backbone only
torch.save({'model': simclr_model.backbone.state_dict()}, 'my_checkpoint.pth')

# load the backbone later (can be in another script)
ckpt = torch.load('my_checkpoint.pth')
backbone = lightly.models.ResNetGenerator()
backbone.load_state_dict(ckpt['model'])
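
As a side note, if the checkpoint was written on a GPU machine and later read on a CPU-only one, torch.load accepts a map_location argument to remap the tensors:

# remap GPU tensors to the CPU while loading
ckpt = torch.load('my_checkpoint.pth', map_location='cpu')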
@Nike682631

Is this issue free to work on?

@philippmwirth (Contributor)

Yes, it is. However, parts may change with the ongoing refactoring, so I'd recommend working on something else at the moment.

@Nike682631 commented Sep 20, 2021 via email

@shree-lily

Has this been added to the documentation?

@guarin (Contributor) commented Sep 22, 2022

Hi @shree-lily, sadly we have not added this to the documentation yet.

You can still use either PyTorch or PyTorch Lightning to save and load checkpoints.

Taking our example code for the SimCLR model here: https://docs.lightly.ai/examples/simclr.html, you can save and load the model as follows:

...

# saving
model = SimCLR(backbone)
torch.save(model.state_dict(), 'simclr_model.ckpt')

# loading
model = SimCLR(backbone)
model.load_state_dict(torch.load('simclr_model.ckpt'))
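
The PyTorch Lightning route looks similar (a sketch, assuming the model was trained with a Lightning Trainer; load_from_checkpoint forwards extra keyword arguments to the module's __init__):

# saving: the Trainer writes a full Lightning checkpoint
trainer.save_checkpoint('simclr_model.ckpt')

# loading: the backbone argument is passed through to SimCLR.__init__
model = SimCLR.load_from_checkpoint('simclr_model.ckpt', backbone=backbone)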

It is common to drop the projection head for downstream or inference tasks and use only the backbone. You can get the backbone with the model.backbone attribute.
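
For example (a sketch, assuming the backbone is the torchvision resnet18 trunk used in the linked example):

# save only the backbone weights for downstream tasks
torch.save(model.backbone.state_dict(), 'simclr_backbone.ckpt')

# later: rebuild the same architecture and load the weights
import torch
import torchvision
from torch import nn

resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
backbone.load_state_dict(torch.load('simclr_backbone.ckpt'))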

@guarin (Contributor) commented Aug 16, 2024

We now have a tutorial on how to finetune checkpoints from lightly: https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_checkpoint_finetuning.html

We also plan to upload the benchmark backbones: #1621

@guarin closed this as completed Aug 16, 2024