lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch
770 stars 124 forks

General changes and adding of support for more functionality #7

Open arkel23 opened 3 years ago

arkel23 commented 3 years ago

This PR makes the following changes:

  • Added support for the 'H-14' and 'L-16' ViT models.
  • Added support for downloading the models directly from Google's cloud storage.
  • Corrected the JAX-to-PyTorch weight conversion. The previous procedure produced .pth state_dict files without the representation layer, so ViT(..., load_repr_layer=True) raised an error. For inference alone the representation layer is unnecessary, as discussed in the original Vision Transformer paper, but it can be useful for other applications and experiments, so I added download_convert_models.py to first download the required models, convert them with all of their weights, and then let you tune the parameters freely.
  • Added support for visualizing attention by returning the score values from the multi-head self-attention layers. The visualization script was mostly taken from the jeonsworld/ViT-pytorch repository.
  • Added examples for inference (single image) and for fine-tuning/training (on CIFAR-10).
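To illustrate the loading-logic change described above (loading a checkpoint whose classifier head or representation layer does not match the target model), here is a minimal framework-free sketch. The key prefixes `pre_logits` and `fc` are assumptions for illustration, not necessarily the repository's exact identifiers:

```python
# Hypothetical sketch of state_dict key filtering when the checkpoint's
# representation layer ("pre_logits") or classifier head ("fc") should be
# skipped. Prefixes are illustrative assumptions, not the actual API.

def filter_state_dict(checkpoint, model_keys, load_repr_layer=True, load_fc=True):
    """Keep only checkpoint entries that the target model can accept."""
    kept = {}
    for key, value in checkpoint.items():
        if not load_repr_layer and key.startswith("pre_logits"):
            continue  # skip the representation layer (inference-only use)
        if not load_fc and key.startswith("fc"):
            continue  # skip the head (e.g. different number of classes)
        if key in model_keys:
            kept[key] = value
    return kept

checkpoint = {
    "patch_embedding.weight": 1,
    "pre_logits.weight": 2,
    "fc.weight": 3,
}
model_keys = {"patch_embedding.weight", "fc.weight"}
filtered = filter_state_dict(checkpoint, model_keys, load_repr_layer=False)
print(sorted(filtered))  # ['fc.weight', 'patch_embedding.weight']
```

In real use the filtered dict would be passed to `model.load_state_dict(filtered, strict=False)`.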

lukemelas commented 3 years ago

Thank you! This PR looks brilliant. I am excited to review it and merge it -- it might take a bit longer than usual due to the holidays, but I'll get to it soon.

On Tue, Dec 22, 2020 at 7:11 AM Edwin Arkel Rios notifications@github.com wrote:


You can view, comment on, or merge this pull request online at:

https://github.com/lukemelas/PyTorch-Pretrained-ViT/pull/7

Commit Summary

  • changed convert.py, added explore-conversion_21k.py script, added logs for the conversion, added download links to download.sh and configs.py for models that were missing
  • restructured directory and made it so that instead of downloading pth beforehand it directly downloads them to torchhub and then converts them on the fly
  • restructured, deleted jax_to_pytorch and moved to utils.py and made sure that it loads the representation layer
  • deleted jax_to_pytorch and added the py to download the models
  • deleted jax_to_pytorch and combined the relevant files into pytorch_pretrained_vit's utils.py
  • added some inference scripts and some annotations in transformer.py
  • added an example for cifar-10 dataset
  • added files and an example to visualize attention; modified the transformer to return head scores when the parameter visualize=True is given, otherwise functionality stays the same
  • changes to allow for visualization and compatibility with torchsummary, plus an example with CIFAR-10; changed the loading logic so that all layers load correctly regardless of whether the fc layers have a different number of classes and/or a representation layer is loaded, and verified that they load properly
  • Update README.md
  • Update README.md
  • Update README.md
  • Update README.md
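The attention-visualization commits above work by returning the per-head score matrices from self-attention. As a framework-free illustration of the underlying computation (names, shapes, and the single-head simplification here are illustrative, not the repository's actual API):

```python
# Minimal sketch (plain Python, single head) of scaled dot-product
# attention that returns the score matrix alongside the output -- the
# general idea behind exposing attention maps for visualization.
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention_with_scores(q, k, v):
    """q, k, v: lists of token vectors. Returns (output, scores)."""
    d = len(q[0])
    scores = []
    for qi in q:
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        scores.append(softmax(logits))
    out = []
    for row in scores:
        out.append([sum(w * vj[t] for w, vj in zip(row, v))
                    for t in range(len(v[0]))])
    return out, scores

q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out, scores = attention_with_scores(q, k, v)
# each row of scores sums to 1 and can be rendered as an attention map
```

Returning `scores` alongside the output, rather than only the output, is what lets a caller overlay the attention weights on the input image patches.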


huananerban commented 2 months ago

@arkel23 I would like to ask why my L-16 pre-trained model still can't be trained; I get the error "Missing keys when loading pretrained weights: []".
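Note that an empty list in that message suggests no keys were actually missing, so the training failure may lie elsewhere. One way to check what really differs is to diff the checkpoint keys against the model keys; a hedged sketch (in real use the two sets would come from `checkpoint.keys()` and `model.state_dict().keys()`):

```python
# Diagnostic sketch: compare checkpoint keys against model keys to see
# which are missing and which are unexpected. Key names are illustrative.

def diff_keys(checkpoint_keys, model_keys):
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected

missing, unexpected = diff_keys(
    {"patch_embedding.weight", "fc.weight"},
    {"patch_embedding.weight", "fc.weight", "pre_logits.weight"},
)
print(missing)     # ['pre_logits.weight']
print(unexpected)  # []
```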