huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
30.66k stars 4.63k forks

[BUG] Loading state dict in a feature extraction network #2215

Open ioangatop opened 1 week ago

ioangatop commented 1 week ago

Describe the bug

Hi Ross! I'm facing a small issue with the feature extractor; here are some details:

The create_model function supports a checkpoint_path argument, which allows loading custom model weights. However, when we create the model as a feature extractor, it gets wrapped in the FeatureGetterNet class and loading fails because the keys no longer match: FeatureGetterNet stores the model under self.model, so for loading to work the state dict keys would need a model. prefix, for example cls_token -> model.cls_token.

Additionally, one workaround is to load the checkpoint after the model has been initialised, but this also fails because some networks, like Vision Transformer, prune some layers in feature-extraction mode, and the checkpoint state_dict therefore has extra keys.
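To make the mismatch concrete, here is a small illustrative snippet (hypothetical, not a proposed fix), using the checkpoint downloaded in the repro below: the wrapper's state_dict expects model.-prefixed keys, while the raw checkpoint has bare ones.

import torch

# Raw DINO checkpoint keys look like "cls_token", "pos_embed", "blocks.0....", etc.
state_dict = torch.load("dino_deitsmall16_pretrain.pth", map_location="cpu")

# FeatureGetterNet keeps the backbone under self.model, so its own state_dict
# uses "model."-prefixed keys ("model.cls_token", "model.pos_embed", ...).
remapped = {f"model.{k}": v for k, v in state_dict.items()}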

To Reproduce

from urllib import request

from timm.models import _helpers
import timm

# download weights
request.urlretrieve("https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", "dino_deitsmall16_pretrain.pth")

# build and load model -- works as expected
model = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for FeatureGetterNet:
#   Missing key(s) in state_dict: "model.cls_token", "model.pos_embed", ...
#  Unexpected key(s) in state_dict: "cls_token", "pos_embed", ...
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for VisionTransformer:
#   Unexpected key(s) in state_dict: "norm.weight", "norm.bias". 
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
)
_helpers.load_checkpoint(backbone.model, "dino_deitsmall16_pretrain.pth")

As always, thanks a lot 🙏

rwightman commented 1 week ago

@ioangatop if you want classifier weights loaded into feature-extraction wrapped models, you need to load the weights as 'pretrained' so that they are applied before the model is mutated.

See the related discussion; this should work with timm versions >= 0.9: https://github.com/huggingface/pytorch-image-models/discussions/1941

Although the example in that discussion should be a bit different: use the 'overlay' arg (pretrained_cfg_overlay) as in the train script https://github.com/huggingface/pytorch-image-models/blob/d4ef0b4d589c9b0cb1d6240ff373c5508dbb8023/train.py#L463-L468

The overlay dict is merged with the model's normal pretrained_cfg, while the pretrained_cfg arg fully overrides it.

As an alternative to the file key in the pretrained_cfg overlay dict, you can also use url to download the weights from somewhere else, or hf_hub_id for a HF Hub location.
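For reference, a minimal sketch of that suggestion applied to the example above (assuming timm >= 0.9; the checkpoint path is the one from the report, and exact loading behaviour may vary per model):

import timm

# Sketch: pass the local checkpoint through the 'pretrained' path via
# pretrained_cfg_overlay, so the weights are loaded into the full model
# before it is wrapped/pruned for feature extraction.
backbone = timm.create_model(
    "vit_small_patch16_224",
    pretrained=True,  # required so the overlaid pretrained_cfg is actually used
    pretrained_cfg_overlay=dict(file="dino_deitsmall16_pretrain.pth"),
    num_classes=0,
    features_only=True,
)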