syp2ysy / VRP-SAM

[CVPR 2024] Official implementation of "VRP-SAM: SAM with Visual Reference Prompt"
MIT License

Loading ResNet #3

Closed: bestrobotplans closed this issue 7 months ago.

bestrobotplans commented 7 months ago

I have a question about the code. In resnet.py, starting around line 318, the code reads:

model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
    #model.load_state_dict(torch.utils.model_zoo.load_url(model_urls['resnet50']))
    model_path = '/root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth'
    model.load_state_dict(torch.load(model_path), strict=False)

This fails because there is no file at /root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth on my machine.

So I tried changing it to:

model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
    model.load_state_dict(torch.utils.model_zoo.load_url(model_urls['resnet50']))
    #model_path = '/root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth'
    #model.load_state_dict(torch.load(model_path), strict=False)

But now I get the following error:

RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv2.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "conv3.weight", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var".
    size mismatch for conv1.weight: copying a param with shape torch.Size([64, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
    size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([64, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1, 1]).
    size mismatch for layer1.0.downsample.0.weight: copying a param with shape torch.Size([256, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
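This error mixes two kinds of problems: keys the model expects but the checkpoint lacks, and keys present in both with different shapes. Note that `strict=False`, as used in the original line, only tolerates missing or unexpected keys; PyTorch still raises on size mismatches. The minimal sketch below reproduces that comparison on plain name-to-shape dictionaries so a checkpoint can be inspected before loading. `compare_shapes` is a hypothetical helper, not part of VRP-SAM; with real state dicts the inputs would be built as `{k: tuple(v.shape) for k, v in sd.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_shapes(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> Dict[str, List]:
    """Split differences into the three groups load_state_dict reports."""
    # Keys the model has but the checkpoint lacks ("Missing key(s)").
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    # Keys only the checkpoint has ("Unexpected key(s)").
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    # Shared keys whose shapes differ ("size mismatch"): (key, checkpoint, model).
    mismatched = sorted(
        (k, ckpt_shapes[k], model_shapes[k])
        for k in model_shapes
        if k in ckpt_shapes and ckpt_shapes[k] != model_shapes[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Subset of the shapes from the error above.
model_shapes = {
    "conv1.weight": (64, 3, 3, 3),
    "conv2.weight": (64, 64, 3, 3),  # assumed shape; the error only says it is missing
    "layer1.0.conv1.weight": (64, 128, 1, 1),
}
ckpt_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "layer1.0.conv1.weight": (64, 64, 1, 1),
}

report = compare_shapes(model_shapes, ckpt_shapes)
for group, items in report.items():
    print(group, items)
```

Reading the mismatches side by side makes the cause visible: the checkpoint's `conv1` is 7x7 and feeds 64 channels into `layer1`, while this model's `conv1` is 3x3 and `layer1` expects 128 input channels.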

Could you help me get a working ResNet backbone here?

syp2ysy commented 7 months ago

This is because we are using ResNetv2. Here is the link to download the weights: https://drive.google.com/file/d/1w5pRmLJXvmQQA5PtCbHhZc_uC4o0YbmA/view
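Judging by the error above, the ResNet in resnet.py starts with a 3x3 `conv1` followed by `conv2` and `conv3`, whereas the torchvision `resnet50` checkpoint has a single 7x7 `conv1`. A quick check of a checkpoint's stem before loading might look like the minimal sketch below. `describe_stem` is a hypothetical helper, and it assumes the downloaded file holds a flat state dict with unprefixed keys, so that `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}` gives its input.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def describe_stem(shapes: Dict[str, Shape]) -> str:
    """Classify a ResNet checkpoint by the layout of its first layers."""
    conv1 = shapes.get("conv1.weight")
    if conv1 is None:
        return "unknown (no conv1.weight key)"
    kernel = conv1[-2:]  # (kernel_h, kernel_w) of the first convolution
    if kernel == (7, 7):
        return "single 7x7 stem (torchvision-style)"
    if kernel == (3, 3) and "conv2.weight" in shapes and "conv3.weight" in shapes:
        return "deep 3x3 stem (conv1/conv2/conv3)"
    return f"unrecognized stem, conv1 shape {conv1}"
```

Because the original code loads with `strict=False`, it is also worth inspecting the `missing_keys` and `unexpected_keys` that `load_state_dict` returns; any layer left unmatched would otherwise stay randomly initialized without an error.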

ImmortalSdm commented 6 months ago

I do think you should change the code to make this clearer.
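One way to make the loading code clearer is to resolve the checkpoint location from an argument or an environment variable and fail with a readable message when the file is missing. This is a minimal sketch: the function name `resolve_resnet_weights` and the environment variable `VRP_SAM_RESNET50_V2` are hypothetical and not part of the VRP-SAM code.

```python
import os
from pathlib import Path
from typing import Optional

# Hypothetical environment variable name; VRP-SAM does not define it.
ENV_VAR = "VRP_SAM_RESNET50_V2"


def resolve_resnet_weights(path: Optional[str] = None) -> Path:
    """Return the checkpoint path from `path` or from ENV_VAR.

    Raises FileNotFoundError with a hint instead of failing later in torch.load.
    """
    candidate = path or os.environ.get(ENV_VAR)
    if not candidate:
        raise FileNotFoundError(
            f"No ResNet weights given: pass a path or set {ENV_VAR} "
            "to the downloaded resnet50_v2.pth"
        )
    resolved = Path(candidate).expanduser()
    if not resolved.is_file():
        raise FileNotFoundError(f"ResNet weights not found at {resolved}")
    return resolved
```

Inside the `if pretrained:` branch, `model.load_state_dict(torch.load(resolve_resnet_weights()), strict=False)` could then replace the hard-coded `model_path`, with the downloaded resnet50_v2.pth placed wherever the variable points.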