huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[BUG] `return_dict` option in feature extractor breaks torch.jit.script #2216

Closed: dsuess closed this issue 2 months ago

dsuess commented 2 months ago

Describe the bug

The recent addition of the return_dict option to FeatureGraphNet and GraphExtractNet breaks torch.jit.script for models that use FX-based feature extraction.

To Reproduce

Steps to reproduce the behavior:

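import timm
import torch

# Build an FX-based feature extractor (feature_cls="fx").
model = timm.create_model(
    model_name="regnetx_008",
    features_only=True,
    feature_cfg={"feature_cls": "fx"},
)

# Fails on affected timm versions: forward's return type depends on return_dict.
torch.jit.script(model)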

breaks with the following error message:

Previous return statement returned a value of type Dict[str, Tensor] but this return statement returns a value of type List[Tensor]:
  File "/opt/conda/envs/dev/lib/python3.11/site-packages/timm/models/_features_fx.py", line 141
        if self.return_dict:
            return out
        return list(out.values())
        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
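The same failure can be reproduced without timm. The sketch below uses a hypothetical toy module, ToyFeatures (not timm code; layer sizes are arbitrary), that mimics the pattern shown in the traceback: a plain bool attribute chooses between returning a dict and a list. Because a non-Final attribute is not a compile-time constant, TorchScript has to compile both return statements and rejects their mismatched types, while eager mode runs fine.

```python
from typing import Dict

import torch
from torch import nn


class ToyFeatures(nn.Module):
    """Hypothetical stand-in for an FX feature extractor (not timm code)."""

    def __init__(self, return_dict: bool = False):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.stage = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        # Plain bool attribute: TorchScript compiles it as a regular module
        # attribute, not a constant, so neither branch below can be pruned.
        self.return_dict = return_dict

    def forward(self, x: torch.Tensor):
        out: Dict[str, torch.Tensor] = {}
        x = self.stem(x)
        out["stem"] = x
        out["stage"] = self.stage(x)
        if self.return_dict:
            return out  # Dict[str, Tensor]
        return list(out.values())  # List[Tensor], conflicts with the return above


# Eager mode works; only scripting fails.
eager_out = ToyFeatures()(torch.randn(1, 3, 8, 8))

error_message = ""
try:
    torch.jit.script(ToyFeatures())
    scripted_ok = True
except Exception as err:
    scripted_ok = False
    error_message = str(err)
    print(error_message)
```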

Expected behavior

Models with feature extraction enabled should still be scriptable with torch.jit.script.


Additional context

A simple fix would be to use separate classes for return_dict=True and return_dict=False, each with a single, fixed return type.
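A minimal sketch of the separate-classes idea, using hypothetical toy modules rather than timm's FeatureGraphNet or GraphExtractNet: the feature-collection code lives in a shared base class, and each subclass defines a forward with one annotated return type, so TorchScript never sees conflicting returns.

```python
from typing import Dict, List

import torch
from torch import nn


class _ToyBackbone(nn.Module):
    """Hypothetical shared body that collects named features (not timm code)."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.stage = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)

    def collect(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        out: Dict[str, torch.Tensor] = {}
        x = self.stem(x)
        out["stem"] = x
        out["stage"] = self.stage(x)
        return out


class ToyFeaturesDict(_ToyBackbone):
    # Always returns a dict: a single, fixed return type.
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return self.collect(x)


class ToyFeaturesList(_ToyBackbone):
    # Always returns a list: a single, fixed return type.
    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        return list(self.collect(x).values())


x = torch.randn(1, 3, 8, 8)
as_dict = torch.jit.script(ToyFeaturesDict())(x)
as_list = torch.jit.script(ToyFeaturesList())(x)
```

Both variants script cleanly. The cost is a second class (and a second constructor surface) for what is otherwise the same module.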

rwightman commented 2 months ago

Hmm, I guess I forgot to add the Final annotation on that return_dict bool; that should fix it without needing different classes. I'm away from my computer for a few days, so you can try that and a PR is welcome; otherwise I'll do it this weekend. You can find other examples, e.g. use_fused_attn: https://github.com/huggingface/pytorch-image-models/blob/d4ef0b4d589c9b0cb1d6240ff373c5508dbb8023/timm/layers/attention2d.py#L93
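A minimal sketch of the Final-annotation approach, again with a hypothetical toy module rather than timm's actual FeatureGraphNet: declaring return_dict as torch.jit.Final[bool] at class level makes it a TorchScript constant. The condition "if self.return_dict:" is then resolved at compile time, only the statements actually reached are compiled, and each scripted instance ends up with a single return type.

```python
from typing import Dict

import torch
from torch import nn
from torch.jit import Final


class ToyFeatures(nn.Module):
    """Hypothetical toy extractor (not timm code); return_dict is Final."""

    # Class-level Final annotation: TorchScript treats the attribute as a
    # constant, so the branch on it is specialized when the module is scripted.
    return_dict: Final[bool]

    def __init__(self, return_dict: bool = False):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.stage = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.return_dict = return_dict

    def forward(self, x: torch.Tensor):
        out: Dict[str, torch.Tensor] = {}
        x = self.stem(x)
        out["stem"] = x
        out["stage"] = self.stage(x)
        if self.return_dict:
            return out  # compiled only when return_dict is True
        return list(out.values())  # compiled only when return_dict is False


x = torch.randn(1, 3, 8, 8)
feats_list = torch.jit.script(ToyFeatures(return_dict=False))(x)
feats_dict = torch.jit.script(ToyFeatures(return_dict=True))(x)
```

Because the flag is baked into the compiled graph at scripting time, it should be treated as fixed once the module is constructed.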

dsuess commented 2 months ago

Thanks for the quick reply! I didn't know about Final. Let me create a PR.