huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXt, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
32.26k stars 4.76k forks

[FEATURE] Add DeiT models #349

Closed Yuxin-CV closed 3 years ago

Yuxin-CV commented 3 years ago

Is your feature request related to a problem? Please describe.

DeiT (https://github.com/facebookresearch/deit) depends heavily on timm, and the data loader implemented in the original DeiT repo is slow compared with timm's loader.

I hope the DeiT models can be added to timm in the near future.

rwightman commented 3 years ago

@Yuxin-CV they just updated their license so that doing what you ask is possible; I just have some other things to work on first. The DeiT models, as you know, use the timm ViT, so it's easy to cut & paste the model defs and weight URLs into vision_transformer.py here.

Supporting the distillation training, etc. would be more work that I'm not in a hurry to do. I have a broader idea of supporting more advanced training techniques for timm models (distillation, BYOL, SimCLR, etc.) but haven't figured out what shape that'd take.
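As a reference point for what "supporting the distillation training" involves: DeiT's hard-label distillation averages the usual cross-entropy against the ground-truth label with a cross-entropy against the teacher's argmax prediction. A dependency-free numerical sketch (plain Python standing in for the PyTorch tensor version; the function names are ours, not timm's, and the real recipe applies the teacher term to a separate distillation token's output rather than the same logits):

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target_idx):
    # Negative log-probability of the target class.
    return -math.log(softmax(logits)[target_idx])

def deit_hard_distill_loss(student_logits, teacher_logits, true_label):
    """DeiT-style hard distillation: average the CE against the ground-truth
    label and the CE against the teacher's hard (argmax) prediction."""
    teacher_label = max(range(len(teacher_logits)),
                        key=teacher_logits.__getitem__)
    return 0.5 * cross_entropy(student_logits, true_label) \
         + 0.5 * cross_entropy(student_logits, teacher_label)
```

Wiring this into the timm train loop would also require running a frozen teacher forward pass per batch, which is part of the extra work being described.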

EDIT: One thing to keep in mind: aside from the tiny/small models that the DeiT repo has model defs and weights for, I don't see a benefit to using, say, the ImageNet-1k-trained DeiT weights for the base-sized model over the raw ImageNet-21k weights or the ImageNet-21k -> 1k fine-tuned weights that I converted from the official JAX repo. The larger-dataset weights should remain a better starting point.

Yuxin-CV commented 3 years ago


Thanks for your reply.