huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
32.3k stars 4.76k forks

[BUG] Get "UnpicklingError: invalid load key, '='." when calling "timm.create_model('beit_base_patch16_224', pretrained=True)" #1263

Closed · IndoorAdventurer closed this 2 years ago

IndoorAdventurer commented 2 years ago

Describe the bug

For several models, including 'beit_base_patch16_224', an exception is thrown when loading the model with pretrained weights. The problem occurs on every machine and Python version I tried, including Google Colab.

To Reproduce

```python
import timm
model = timm.create_model('beit_base_patch16_224', pretrained=True)
```

This gives the following stack trace:

```
UnpicklingError                           Traceback (most recent call last)
<ipython-input> in <module>()
----> 1 model = timm.create_model('beit_base_patch16_224', pretrained=True)

7 frames
/usr/local/lib/python3.7/dist-packages/timm/models/factory.py in create_model(model_name, pretrained, checkpoint_path, scriptable, exportable, no_jit, **kwargs)
     72
     73     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
---> 74         model = create_fn(pretrained=pretrained, **kwargs)
     75
     76     if checkpoint_path:

/usr/local/lib/python3.7/dist-packages/timm/models/beit.py in beit_base_patch16_224(pretrained, **kwargs)
    359         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
    360         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
--> 361     model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    362     return model
    363

/usr/local/lib/python3.7/dist-packages/timm/models/beit.py in _create_beit(variant, pretrained, default_cfg, **kwargs)
    350         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
    351         pretrained_filter_fn=checkpoint_filter_fn,
--> 352         **kwargs)
    353     return model
    354

/usr/local/lib/python3.7/dist-packages/timm/models/helpers.py in build_model_with_cfg(model_cls, variant, pretrained, default_cfg, model_cfg, feature_cfg, pretrained_strict, pretrained_filter_fn, pretrained_custom_load, kwargs_filter, **kwargs)
    468             in_chans=kwargs.get('in_chans', 3),
    469             filter_fn=pretrained_filter_fn,
--> 470             strict=pretrained_strict)
    471
    472     # Wrap the model in a feature extraction module if enabled

/usr/local/lib/python3.7/dist-packages/timm/models/helpers.py in load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn, strict, progress)
    187     if pretrained_url:
    188         _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
--> 189         state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
    190     elif hf_hub_id and has_hf_hub(necessary=True):
    191         _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')

/usr/local/lib/python3.7/dist-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    593     if _is_legacy_zip_format(cached_file):
    594         return _legacy_zip_load(cached_file, model_dir, map_location)
--> 595     return torch.load(cached_file, map_location=map_location)

/usr/local/lib/python3.7/dist-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    711                 return torch.jit.load(opened_file)
    712             return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 713     return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    714
    715

/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    918                     "functionality.")
    919
--> 920     magic_number = pickle_module.load(f, **pickle_load_args)
    921     if magic_number != MAGIC_NUMBER:
    922         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, '='.
```

Kind regards!

rwightman commented 2 years ago

@IndoorAdventurer what version of timm? The beit checkpoints changed location; master is updated, but I haven't backported the fix to the PyPI 0.5.x release... I did do a 0.6.2.dev0 release that can be installed with `pip install --pre timm` or `pip install timm==0.6.2.dev0`.
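To make the version cutoff from the comment above concrete: the 0.5.x PyPI releases still point at the old beit checkpoint URLs, so any version below 0.6 hits this error. A small sketch for checking whether an installed version predates the fix (the `needs_upgrade` helper is hypothetical, not part of timm; in practice you would pass it `timm.__version__`):

```python
def needs_upgrade(version: str) -> bool:
    """Return True if this timm version predates the beit checkpoint URL fix.

    Compares only the numeric major.minor part, so pre-release suffixes
    like "0.6.2.dev0" are handled correctly.
    """
    parts = version.split(".")
    return (int(parts[0]), int(parts[1])) < (0, 6)

print(needs_upgrade("0.5.4"))       # True: old checkpoint URLs, will fail
print(needs_upgrade("0.6.2.dev0"))  # False: updated URLs
```

If it returns True, upgrade with `pip install --pre timm` or `pip install timm==0.6.2.dev0` as noted above.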

IndoorAdventurer commented 2 years ago

Ah, I was using the default 0.5.4 version. I just tried 0.6.2 on Google Colab and it works! :-D (I did try 0.6.2 locally earlier and got the same error, so I probably still had a stale .pth file cached.)
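For anyone hitting the same trap: a previously downloaded bad checkpoint stays in the torch hub cache, so upgrading timm alone may not be enough. A minimal sketch for clearing stale checkpoints, assuming the cached files match a `beit_*.pth` glob pattern (the helper and pattern are illustrative, not a timm API; `torch.hub.get_dir()` in the usage comment is the real cache-location call):

```python
import glob
import os


def remove_stale_checkpoints(cache_dir: str, pattern: str = "beit_*.pth") -> list:
    """Delete cached checkpoint files matching `pattern` under `cache_dir`.

    Returns the list of paths that were removed. The glob pattern is an
    assumption -- inspect the actual filenames in your cache first.
    """
    removed = []
    for path in glob.glob(os.path.join(cache_dir, pattern)):
        os.remove(path)
        removed.append(path)
    return removed


# Typical usage against PyTorch's hub cache (requires torch installed):
#   import torch
#   remove_stale_checkpoints(os.path.join(torch.hub.get_dir(), "checkpoints"))
```

Re-running `timm.create_model(..., pretrained=True)` afterwards re-downloads the weights from the corrected URL.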

Thank you so much for the quick response!

Kind regards Vincent