We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:
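from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50

model = resnet50()
state_dict = load_state_dict_from_url("https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt")
# Keep only the backbone weights and drop the "backbone." prefix.
prefix = "backbone."
new_state_dict = {}
for key, value in state_dict["state_dict"].items():
    if key.startswith(prefix):
        # Slice the prefix off: str.lstrip strips a set of characters, not a
        # prefix, so it would turn "backbone.conv1.weight" into "v1.weight".
        new_state_dict[key[len(prefix):]] = value
# strict=False because the checkpoint has no fc layer; the default would raise.
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
# missing_keys is a list, so compare it as a set.
assert set(missing_keys) == {"fc.weight", "fc.bias"}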
Maybe we can leave the load_state_dict_from_url call outside the function and have the function just take a state dict as input and return the new state dict as output.
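A minimal sketch of what that function could look like. The name `extract_backbone_state_dict` and the `prefix` parameter are hypothetical, not an existing API. It assumes the weights sit either directly in the mapping or under a `"state_dict"` key, as in the checkpoint above:

```python
from typing import Any, Dict, Mapping


def extract_backbone_state_dict(
    checkpoint: Mapping[str, Any], prefix: str = "backbone."
) -> Dict[str, Any]:
    """Return only the entries under `prefix`, with the prefix removed.

    Hypothetical helper: the name and signature are illustrative.
    """
    # Accept either a full checkpoint (weights nested under "state_dict")
    # or a plain state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    return {
        # Slice instead of str.lstrip, which strips characters, not a prefix.
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

The caller would still download the checkpoint with `load_state_dict_from_url` and pass the result to `model.load_state_dict(..., strict=False)`, so the function itself stays free of network access and is easy to test with a plain dict.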