lightly-ai / lightly

A Python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Add function to load model from pretrained checkpoint #1475

Closed guarin closed 1 month ago

guarin commented 8 months ago

We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:

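from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50

model = resnet50()
# The checkpoint stores the model weights under the "state_dict" key.
state_dict = load_state_dict_from_url("https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt", map_location="cpu")
prefix = "backbone."
new_state_dict = {}
for key, value in state_dict["state_dict"].items():
    if key.startswith(prefix):
        # Slice the prefix off. str.lstrip would strip a set of characters, not the prefix.
        new_state_dict[key[len(prefix):]] = value
# strict=False is needed because the filtered state dict has no fc layer weights.
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
# missing_keys is a list, so convert it before comparing against a set.
assert set(missing_keys) == {"fc.weight", "fc.bias"}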

Maybe we can leave load_state_dict_from_url outside the function and make the function just take a state dict as input and return the new state dict as output.
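A minimal sketch of what such a state-dict-only helper might look like. The name extract_backbone_state_dict and the prefix parameter are hypothetical, not part of the lightly API. The sketch uses plain Python values in place of tensors so that it runs without torch.

```python
from typing import Any, Dict, Mapping


def extract_backbone_state_dict(
    state_dict: Mapping[str, Any], prefix: str = "backbone."
) -> Dict[str, Any]:
    """Return the entries whose key starts with `prefix`, with the prefix removed.

    Hypothetical helper: the name and signature are assumptions for illustration.
    """
    return {
        # Slice by length so that only the exact prefix is removed.
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Plain integers stand in for tensors; the key names are made up for the example.
checkpoint_state = {
    "backbone.conv1.weight": 1,
    "backbone.bn1.bias": 2,
    "projection_head.layers.0.weight": 3,
}
backbone_state = extract_backbone_state_dict(checkpoint_state)
```

Slicing by the prefix length matters here: "backbone.bn1.bias".lstrip("backbone.") yields "1.bias", because lstrip removes any leading characters from the given set rather than the literal prefix. The caller would still download the checkpoint and pass its "state_dict" entry to the helper, then load the result with strict=False.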

TODO

guarin commented 1 month ago

Closed in favor of https://github.com/lightly-ai/lightly/issues/1621