huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXt, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[BUG] Symbolic tracing not working for few models #1411

Closed soumendukrg closed 1 year ago

soumendukrg commented 2 years ago

Describe the bug
I am working on quantizing a few timm models using Torch FX Graph Mode Quantization, specifically post-training static quantization. For static models like ResNet, ResNeXt, and DeiT, following the PyTorch quantization tutorial, I am able to quantize the model. However, for models with dynamic control flow like MobileNetV3, MobileViT, and EfficientNetV2, I am not able to do so, as these models are not symbolically traceable. As per my understanding, when I set the exportable flag to True in the call to timm.create_model(...), the model should be traceable end to end. Please clarify.
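For reference, this is roughly the flow I am following (a minimal sketch assuming the current torch.ao.quantization FX API; the model name and input shape are just examples):

```python
import torch
import timm
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Example model: this flow works for static models like resnet50,
# but fails at prepare_fx for models with dynamic control flow.
model = timm.create_model('mobilenetv3_large_100', pretrained=True, exportable=True)
model.eval()

example_inputs = (torch.randn(1, 3, 224, 224),)
qconfig_mapping = get_default_qconfig_mapping('fbgemm')

# prepare_fx symbolically traces the model, so this is where the
# tracing error is raised for non-traceable models.
prepared = prepare_fx(model, qconfig_mapping, example_inputs)

# Calibrate on representative data (a single random batch here).
with torch.no_grad():
    prepared(*example_inputs)

quantized = convert_fx(prepared)
```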

[screenshot: symbolic tracing error traceback]

Please ignore the commented line # return torch.max(....) in the screenshot.

I have tried workarounds like wrapping the non-traceable parts of the code in a separate module and then specifying the qconfig_dict accordingly in the quantization steps, but I still face the same issue.
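One way to express that wrapping, assuming the newer PrepareCustomConfig API (DynamicPad and Net below are hypothetical stand-ins for the non-traceable block and the surrounding model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx

class DynamicPad(nn.Module):
    # Hypothetical stand-in for code with input-dependent control flow
    # (e.g. dynamic 'same' padding) that FX cannot trace through.
    def forward(self, x):
        pad = x.shape[-1] % 2  # data-dependent value
        return F.pad(x, (0, pad)) if pad else x

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pad = DynamicPad()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(self.pad(x))

model = Net().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Treat DynamicPad as a leaf module so FX never traces into it;
# it stays unquantized while the rest of the graph is prepared.
custom_config = PrepareCustomConfig().set_non_traceable_module_classes([DynamicPad])
prepared = prepare_fx(
    model,
    get_default_qconfig_mapping('fbgemm'),
    example_inputs,
    prepare_custom_config=custom_config,
)
```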

Note: Calling torch.jit.trace(model, sample_input) does execute with some warnings, but no errors.

[screenshot: torch.jit.trace warnings]
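For completeness, the trace call that succeeds looks like this (model name is just an example). torch.jit.trace records one concrete execution path, so data-dependent branches are baked in, which is why it only emits TracerWarnings rather than the hard errors torch.fx raises:

```python
import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=True, exportable=True).eval()
# jit.trace follows a single concrete execution, so dynamic branches
# produce warnings instead of tracing failures.
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
```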

Please share any advice on how to make these models all traceable.


soumendukrg commented 2 years ago

@rwightman Please let me know if you need any other information or logs related to this.

Zagreus98 commented 1 year ago

Hi @rwightman, is there any workaround for this issue?

Zagreus98 commented 1 year ago

For anyone having the same problem: the workaround was to use models that were not developed after the TensorFlow counterparts. For example, changing 'tf_efficientnet_b3' to 'efficientnet_b3', which does not have dynamic control flow, solved the problem. For torch.fx limitations, see: https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
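A quick way to check this (a sketch; assuming the native variant traces cleanly on your timm version):

```python
import timm
import torch.fx

# 'tf_efficientnet_b3' ports TF weights and uses dynamic SAME padding,
# which breaks symbolic tracing; the native 'efficientnet_b3' uses
# static padding and should trace end to end.
model = timm.create_model('efficientnet_b3', pretrained=False, exportable=True).eval()
graph_module = torch.fx.symbolic_trace(model)
print(graph_module.graph)
```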

rwightman commented 1 year ago

This should be improved by recent changes; it's always better to use the non-tf_ prefixed models that don't have the dynamic padding (when possible).