Official PyTorch implementation of RepViT-SAM and RepViT. CVPR 2024.
Models are deployed on iPhone 12 with Core ML Tools to get latency.
Models are trained on ImageNet-1K and deployed on iPhone 12 with Core ML Tools to get latency.
RepViT-SAM: Towards Real-Time Segmenting Anything.\
Ao Wang, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding\
[arXiv] [Project Page]
RepViT: Revisiting Mobile CNN From ViT Perspective.\
Ao Wang, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding\
[arXiv]
UPDATES 🔥
Model | Top-1 (300 / 450) | #params | MACs | Latency | Ckpt | Core ML | Log |
---|---|---|---|---|---|---|---|
M0.9 | 78.7 / 79.1 | 5.1M | 0.8G | 0.9ms | 300e / 450e | 300e / 450e | 300e / 450e |
M1.0 | 80.0 / 80.3 | 6.8M | 1.1G | 1.0ms | 300e / 450e | 300e / 450e | 300e / 450e |
M1.1 | 80.7 / 81.2 | 8.2M | 1.3G | 1.1ms | 300e / 450e | 300e / 450e | 300e / 450e |
M1.5 | 82.3 / 82.5 | 14.0M | 2.3G | 1.5ms | 300e / 450e | 300e / 450e | 300e / 450e |
M2.3 | 83.3 / 83.7 | 22.9M | 4.5G | 2.3ms | 300e / 450e | 300e / 450e | 300e / 450e |
Tips: Convert a training-time RepViT into the inference-time structure
from timm.models import create_model
import utils
model = create_model('repvit_m0_9')
utils.replace_batchnorm(model)
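The internals of `utils.replace_batchnorm` are not shown in this README. As a rough illustration of the general idea behind such a conversion, the sketch below folds BatchNorm statistics into the weight and bias of the preceding layer, using plain Python lists and a toy linear layer (equivalent to a 1x1 convolution at a single pixel). All function names here (`fuse_scale_bn`, `linear`, `batchnorm`) are hypothetical and are not part of the repository.

```python
import math


def fuse_scale_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold per-channel BatchNorm parameters into the preceding layer.

    weight: one row of input weights per output channel (list of lists).
    Returns (fused_weight, fused_bias) such that the fused layer alone
    reproduces layer + BatchNorm in inference mode.
    """
    fused_w, fused_b = [], []
    for w_row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)          # per-channel BN scale
        fused_w.append([w * scale for w in w_row])
        fused_b.append(bt + (b - m) * scale)    # shift absorbs mean and beta
    return fused_w, fused_b


def linear(weight, bias, x):
    """Toy linear layer: y_c = sum_i weight[c][i] * x[i] + bias[c]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm using fixed running statistics."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


# Toy values, chosen only for illustration.
weight = [[0.5, -1.0], [2.0, 0.25]]
bias = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [0.9, 1.6]
x = [1.0, 2.0]

reference = batchnorm(linear(weight, bias, x), gamma, beta, mean, var)
fused_w, fused_b = fuse_scale_bn(weight, bias, gamma, beta, mean, var)
fused = linear(fused_w, fused_b, x)

# The fused layer matches layer + BatchNorm up to floating-point error.
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(reference, fused))
```

Because the fused layer needs no separate normalization step, the inference-time structure has fewer operations than the training-time one while producing the same outputs.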
The latency reported in RepViT for iPhone 12 (iOS 16) was measured with the benchmark tool in Xcode 14. For example, here is a latency measurement of RepViT-M0.9:
Tips: export the model to Core ML format
python export_coreml.py --model repvit_m0_9 --ckpt pretrain/repvit_m0_9_distill_300e.pth
Tips: measure the throughput on GPU
python speed_gpu.py --model repvit_m0_9
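For readers who want to see what a throughput measurement typically involves, here is a minimal, framework-agnostic sketch. It is not the code in `speed_gpu.py`; the function `measure_throughput` and its parameters are assumptions made for illustration. On a GPU, the `sync` argument would usually be something like `torch.cuda.synchronize`, so that asynchronous kernels have finished before the timer is read.

```python
import time


def measure_throughput(step, batch_size, warmup=5, iters=20, sync=None):
    """Return images per second for `step`, a callable that processes one batch.

    `sync` is an optional callable that blocks until pending work is done
    (for example torch.cuda.synchronize on a GPU).
    """
    for _ in range(warmup):      # warm-up runs are excluded from timing
        step()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


def dummy_step():
    """Stand-in workload; replace with a real forward pass on one batch."""
    sum(i * i for i in range(10_000))


throughput = measure_throughput(dummy_step, batch_size=32)
print(f"{throughput:.1f} images/s")
```

Warm-up iterations matter because the first few batches often include one-time costs such as memory allocation or kernel selection.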
A conda virtual environment is recommended.
conda create -n repvit python=3.8
conda activate repvit
pip install -r requirements.txt
Download and extract ImageNet train and val images from http://image-net.org/. The training and validation data are expected to be in the train folder and val folder respectively:
|-- /path/to/imagenet/
    |-- train
    |-- val
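Before starting a long training run, it can help to confirm that the dataset directory matches this layout. The following is a small sketch using only the standard library; `check_imagenet_layout` is a hypothetical helper, not something provided by this repository.

```python
import tempfile
from pathlib import Path


def check_imagenet_layout(root):
    """Return the split folders missing under `root` (empty list if complete)."""
    root = Path(root).expanduser()
    return [split for split in ("train", "val") if not (root / split).is_dir()]


# Demonstration on a temporary directory that only contains `train`.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "train").mkdir()
    missing = check_imagenet_layout(tmp)
    print(missing)  # ['val']
```

In practice one would call `check_imagenet_layout("/path/to/imagenet")` and stop early if the returned list is not empty.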
To train RepViT-M0.9 on an 8-GPU machine:
python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repvit_m0_9 --data-path ~/imagenet --dist-eval
Tips: specify your data path and model name!
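When launching several runs with different models or data paths, the command above can be assembled programmatically. This is a sketch under the assumption that only the flags shown in the command above are needed; `build_train_command` is a hypothetical helper and the data path `/data/imagenet` is a placeholder.

```python
import shlex


def build_train_command(model, data_path, num_gpus=8, master_port=12346):
    """Assemble the distributed training command as an argument list."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}",
        "--master_port", str(master_port),
        "--use_env", "main.py",
        "--model", model,
        "--data-path", data_path,
        "--dist-eval",
    ]


cmd = build_train_command("repvit_m0_9", "/data/imagenet")
print(shlex.join(cmd))
```

The argument list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting problems when paths contain spaces.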
For example, to test RepViT-M0.9:
python main.py --eval --model repvit_m0_9 --resume pretrain/repvit_m0_9_distill_300e.pth --data-path ~/imagenet
Object Detection and Instance Segmentation
Semantic Segmentation
The classification (ImageNet) code base is partly built on LeViT, PoolFormer, and EfficientFormer.
The detection and segmentation pipeline is from MMCV (MMDetection and MMSegmentation).
Thanks for the great implementations!
If our code or models help your work, please cite our paper:
@inproceedings{wang2024repvit,
title={{RepViT}: Revisiting Mobile {CNN} From {ViT} Perspective},
author={Wang, Ao and Chen, Hui and Lin, Zijia and Han, Jungong and Ding, Guiguang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={15909--15920},
year={2024}
}
@misc{wang2023repvitsam,
title={RepViT-SAM: Towards Real-Time Segmenting Anything},
author={Ao Wang and Hui Chen and Zijia Lin and Jungong Han and Guiguang Ding},
year={2023},
eprint={2312.05760},
archivePrefix={arXiv},
primaryClass={cs.CV}
}