jlianglab / Ark


Error loading pretrained weights--key mismatch #1

Closed leibyj closed 8 months ago

leibyj commented 8 months ago

When I try to load the pretrained weights into the timm model:

import torch
import timm
print(timm.__version__)

model = timm.create_model('swin_base_patch4_window7_224', num_classes=3, pretrained=False)

state_dict = torch.load('/Users/leibyj/Downloads/ark6_teacher_ep200_swinb_projector1376_mlp.pth.tar', map_location="cpu")

# Drop the classification-head keys so they don't clash with the new 3-class head
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
    if k in state_dict:
        print(f"Removing key {k} from pretrained checkpoint")
        del state_dict[k]

print(model.load_state_dict(state_dict, strict=False))

I get the following output:

0.5.4
Removing key head.weight from pretrained checkpoint
Removing key head.bias from pretrained checkpoint
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['projector.0.weight', 'projector.0.bias', 'projector.2.weight', 'projector.2.bias', 'omni_heads.0.weight', 'omni_heads.0.bias', 'omni_heads.1.weight', 'omni_heads.1.bias', 'omni_heads.2.weight', 'omni_heads.2.bias', 'omni_heads.3.weight', 'omni_heads.3.bias', 'omni_heads.4.weight', 'omni_heads.4.bias', 'omni_heads.5.weight', 'omni_heads.5.bias'])

Are the projector and omni_heads specific to the ARK pretraining, and therefore irrelevant when loading the backbone weights?

Mda233 commented 6 months ago

Yes, you are correct. The projector and omni_heads are specific to the ARK pretraining and are not part of the Swin Transformer. Sorry for the late reply; normally I should receive an email for any new issue, but I didn't get a notification for this one.
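For anyone hitting the same message, here is a minimal sketch of how the pretraining-only keys could be filtered out before loading. It assumes the checkpoint is a plain state dict (as in the snippet above) and that the file name matches the released Ark-6 teacher weights; adjust the path to wherever you downloaded them.

import torch
import timm

model = timm.create_model('swin_base_patch4_window7_224', num_classes=3, pretrained=False)

# Path to the downloaded Ark checkpoint (example name, adjust as needed)
ckpt_path = 'ark6_teacher_ep200_swinb_projector1376_mlp.pth.tar'
state_dict = torch.load(ckpt_path, map_location="cpu")

# Keep only the Swin backbone weights: drop the Ark-specific projector/omni_heads
# as well as any classification-head keys, since the new head is trained from scratch.
drop_prefixes = ('projector.', 'omni_heads.', 'head.', 'head_dist.')
backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith(drop_prefixes)}

result = model.load_state_dict(backbone_sd, strict=False)
# Only the freshly initialized classifier should be reported as missing.
assert set(result.missing_keys) <= {'head.weight', 'head.bias'}
assert not result.unexpected_keys
print("Loaded backbone weights; the new head will be trained from scratch.")

With this filtering, the only missing keys are the new classifier's head.weight and head.bias, and there are no unexpected keys, which is the expected result when transferring just the backbone.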