facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Incompatible weights #550

Closed VicaYang closed 2 years ago

VicaYang commented 2 years ago

In vissl/utils/checkpoint.py:743, there is code that validates the loaded weights. As you can see, it checks for the _feature_blocks prefix:

    trunk_append_prefix, heads_append_prefix = "trunk._feature_blocks.", "heads."
    if is_feature_extractor_model(config.MODEL):
        trunk_append_prefix = "trunk.base_model._feature_blocks."

    is_compatible = True
    for layername in state_dict.keys():
        if not (
            layername.startswith(trunk_append_prefix)
            or layername.startswith(heads_append_prefix)
        ):
            is_compatible = False
            break

However, some VISSL models, such as ViT, do not have this structure; you can download the weights from here to check. Other archs like ResNet50 do follow it; you can download the weights from here to check.

In [4]: list(torch.load('model_final_checkpoint_phase208.torch')['classy_state_dict']['base_model']['model']['trunk'].keys())[:10]
Out[4]:
['_feature_blocks.conv1.weight',
 '_feature_blocks.bn1.weight',
 '_feature_blocks.bn1.bias',
 '_feature_blocks.bn1.running_mean',
 '_feature_blocks.bn1.running_var',
 '_feature_blocks.bn1.num_batches_tracked',
 '_feature_blocks.layer1.0.conv1.weight',
 '_feature_blocks.layer1.0.bn1.weight',
 '_feature_blocks.layer1.0.bn1.bias',
 '_feature_blocks.layer1.0.bn1.running_mean']

In [5]: list(torch.load('vit_b16_p16_in22k_ep90_supervised.torch')['classy_state_dict']['base_model']['model']['trunk'].keys())[:10]
Out[5]:
['cls_token',
 'pos_embed',
 'patch_embed.proj.weight',
 'patch_embed.proj.bias',
 'blocks.0.norm1.weight',
 'blocks.0.norm1.bias',
 'blocks.0.attn.qkv.weight',
 'blocks.0.attn.qkv.bias',
 'blocks.0.attn.proj.weight',
 'blocks.0.attn.proj.bias']
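To make the failure concrete, here is a minimal standalone reproduction of the prefix check above, run against the two key layouts shown in this issue (key names copied from the checkpoint listings; the config/feature-extractor handling is omitted for brevity):

```python
# Simplified version of the compatibility check from vissl/utils/checkpoint.py.
trunk_append_prefix, heads_append_prefix = "trunk._feature_blocks.", "heads."

def is_compatible(state_dict_keys):
    # Every key must start with one of the two hard-coded prefixes.
    return all(
        k.startswith(trunk_append_prefix) or k.startswith(heads_append_prefix)
        for k in state_dict_keys
    )

# ResNet50-style keys carry the _feature_blocks prefix and pass the check.
resnet_keys = ["trunk._feature_blocks.conv1.weight", "heads.0.clf.0.weight"]
# ViT-style trunk keys do not, so the whole checkpoint is rejected.
vit_keys = ["trunk.cls_token", "trunk.pos_embed", "heads.0.clf.0.weight"]

print(is_compatible(resnet_keys))  # True
print(is_compatible(vit_keys))     # False: ViT trunk keys lack _feature_blocks
```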
VicaYang commented 2 years ago

I would suggest removing this check and instead issuing a warning when the state_dict is not compatible with the model's own keys, rather than hard-coding "_feature_blocks".
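A minimal sketch of that suggestion (the helper name is hypothetical, not code from VISSL): compare the checkpoint keys against the model's state_dict keys and warn on mismatch rather than rejecting the checkpoint outright.

```python
import logging

def check_state_dict_keys(state_dict_keys, model_keys):
    """Hypothetical replacement for the prefix check: return True if the
    checkpoint keys match the model's keys, otherwise log a warning
    listing the differences instead of hard-failing."""
    missing = set(model_keys) - set(state_dict_keys)
    unexpected = set(state_dict_keys) - set(model_keys)
    if missing or unexpected:
        logging.warning(
            "Incompatible state_dict: missing=%s unexpected=%s",
            sorted(missing), sorted(unexpected),
        )
        return False
    return True

# A ViT-style checkpoint passes against a model with matching keys,
# with no dependence on a "_feature_blocks" prefix.
vit_keys = ["cls_token", "pos_embed", "patch_embed.proj.weight"]
print(check_state_dict_keys(vit_keys, vit_keys))       # True
print(check_state_dict_keys(vit_keys, ["cls_token"]))  # False (and warns)
```

This mirrors what torch.nn.Module.load_state_dict already reports via its missing_keys/unexpected_keys result when called with strict=False.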

VicaYang commented 2 years ago

Well, it seems that this function has been modified on the main branch, but not in the pre-built library or the v0.1.6 branch (which the tutorials suggest using).