muhd-umer / pvt-flax

Unofficial JAX/Flax implementation of Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions.
https://arxiv.org/abs/2102.12122
MIT License

Pyramid Vision Transformer


This repo contains an unofficial JAX/Flax implementation of PVT v2: Improved Baselines with Pyramid Vision Transformer.
All credit goes to the authors Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, and Ling Shao for their wonderful work.

Dependencies

It is recommended to create a new virtual environment so that updates/downgrades of packages do not break other projects.

Note: Flax does not depend on TensorFlow itself; however, we use utilities built on tf.io.gfile, so we install only tensorflow-cpu. The same goes for PyTorch: it is installed solely to use torch.utils.data.DataLoader.
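The usual pattern for feeding a PyTorch DataLoader into JAX is a custom collate_fn that stacks samples into NumPy arrays instead of torch tensors, so batches can be passed straight to jitted functions. A minimal sketch (the toy dataset is illustrative, not this repo's data pipeline):

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Stand-in dataset yielding (features, label) pairs as NumPy arrays."""
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return np.float32(np.arange(4) + i), np.int64(i % 2)

def numpy_collate(batch):
    # stack samples into NumPy arrays rather than torch tensors,
    # so each batch is directly consumable by JAX
    xs, ys = zip(*batch)
    return np.stack(xs), np.asarray(ys)

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=numpy_collate)
for xb, yb in loader:
    print(xb.shape, yb.shape)  # (4, 4) (4,)
```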

Run

To get started, clone this repo and install the required dependencies.
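The steps above might look like the following. This assumes the repo lives at github.com/muhd-umer/pvt-flax and ships a requirements.txt; adjust if the repo uses a different layout.

```shell
# create and activate a fresh virtual environment
python -m venv .venv && source .venv/bin/activate

# clone the repo and install its dependencies
git clone https://github.com/muhd-umer/pvt-flax.git
cd pvt-flax
pip install -r requirements.txt
```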

Datasets

Training

Evaluation

Note: Since my undergrad studies are resuming after summer break, I may not be able to find time to complete the tasks above. If you want to implement them, I'll be more than glad to merge your pull request. ❤️

Acknowledgements

We acknowledge the excellent implementations of PVT in MMDetection, PyTorch Image Models, and the official repository, which served as references for this work.

Citing PVT