MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

VMamba is much slower than Transformer; is the model valid? #234

Open wangyao0512 opened 2 weeks ago

wangyao0512 commented 2 weeks ago

I have been running the Swin_Transformer and VMamba models on the same A800 GPU, using the same batch size and the COCO2017 detection dataset. However, I've observed that VMamba is at least 5 times slower than Swin_Transformer at the same training stage. VMamba config: mask_rcnn_coco_fpn_base.py; Swin_Transformer config: mask_rcnn_swin-t-p4-w7-1x-coco_base.py

[Screenshot: WeChat image_20240615185604]

Additionally, when testing on my own dataset, VMamba's results were inferior to those of the vanilla Transformer. What could be the possible reasons? Is the VMamba model real and valid?

MzeroMiko commented 2 weeks ago

About the training speed

VMamba-T is slower than Swin-T in training, but 5 times slower is far beyond what I would expect; in fact, the gap is less than 2x in training. The appendix of the arXiv paper (v3) says: [image]

Also, I immediately reproduced my experiments on a node with 8 x A800 GPUs and an Intel Xeon Platinum 8358 CPU. The results show that VMamba-T with Mask R-CNN takes about 11 hours to train, while Swin-T needs only 7 hours, which is consistent with the arXiv paper.
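As a quick check of the numbers above, the measured wall-clock times give a slowdown factor below 2x, far from the 5x observed in the issue:

```latex
\frac{t_{\text{VMamba-T}}}{t_{\text{Swin-T}}} = \frac{11\,\text{h}}{7\,\text{h}} \approx 1.57 < 2
```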
By the way, if you are using VMamba-T and want to start from the checkpoint pretrained on classification, you need to load that checkpoint explicitly: the config parameter model.backbone.pretrained is left blank in the VMamba config, whereas in the Swin config it is filled with a URL pointing to the checkpoint storage.

[4 screenshots]

About the performance

As I've emphasized, be careful with checkpoint loading; that may be the key to the performance in downstream tasks.
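One way to check whether a checkpoint actually matched the backbone is to compare key names before loading. The helper below is a hypothetical, framework-free sketch: it assumes the classification checkpoint may nest its weights under "model" or "state_dict" and may carry a "module." prefix (common conventions, not confirmed for VMamba). With PyTorch, backbone.load_state_dict(state, strict=False) also returns missing_keys and unexpected_keys, which carry the same information.

```python
from typing import Iterable, List, Mapping, Tuple


def unwrap_checkpoint(ckpt: Mapping) -> Mapping:
    # Assumption: trainers often nest weights under "model" or "state_dict".
    for key in ("model", "state_dict"):
        inner = ckpt.get(key)
        if isinstance(inner, Mapping):
            return inner
    return ckpt


def strip_prefix(state: Mapping, prefix: str) -> dict:
    # Drop a wrapper prefix such as "module." left by DistributedDataParallel.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}


def diff_keys(model_keys: Iterable[str], ckpt: Mapping,
              prefix: str = "module.") -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists for a checkpoint vs. a model."""
    state = strip_prefix(unwrap_checkpoint(ckpt), prefix)
    expected, loaded = set(model_keys), set(state)
    missing = sorted(expected - loaded)     # backbone params left at random init
    unexpected = sorted(loaded - expected)  # e.g. classification head weights
    return missing, unexpected


# Hypothetical key names, for illustration only.
backbone_keys = ["patch_embed.weight", "layers.0.norm.weight"]
ckpt = {"model": {"module.patch_embed.weight": 0,
                  "module.layers.0.norm.weight": 0,
                  "module.head.weight": 0}}
missing, unexpected = diff_keys(backbone_keys, ckpt)
print("missing:", missing)        # missing: []
print("unexpected:", unexpected)  # unexpected: ['head.weight']
```

A long missing list usually means the backbone was trained from scratch rather than from the pretrained weights.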

Finally, thank you for your interest. Feel free to raise issues here, and I'll do my best to answer when I am available.

wangyao0512 commented 2 weeks ago

Thank you very much for your reply. May I know your detailed configuration?

Assuming everything runs successfully, what could cause the slow speed apart from environment configuration problems?

My environment is: CUDA 11.8 + cuDNN 8700, Python 3.10.0, causal-conv1d 1.2.1, mamba-ssm 2.0.3, Torch 2.1.0+cu118, triton 2.1.0
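A report like this can be collected in one go with the standard library's importlib.metadata. The sketch below is illustrative: the package list simply mirrors the report above and does not include VMamba's own selective_scan kernel, whose installed name is not given here. CUDA and cuDNN versions could be added from torch.version.cuda and torch.backends.cudnn.version() once torch is importable; that is left out to keep the sketch dependency-free.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names taken from the environment report above.
PACKAGES = ("torch", "triton", "mamba-ssm", "causal-conv1d")


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each package to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```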

MzeroMiko commented 2 weeks ago
  1. Our code does not use mamba_ssm or causal-conv1d; instead, we adapted the kernel into kernels/selective_scan, which needs to be installed.

  2. Another factor that affects speed is cross_scan and cross_merge. If you are using an old version of this file, update the code. You can also check csm_triton.py#CrossScanTriton and csm_triton.py#CrossMergeTriton (in the latest version, these two classes have been replaced by a more general version, cross_scan_fn and cross_merge_fn). If you find the line BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64) in them, change it to BC, BH, BW = 1, 32, 32 or BC, BH, BW = 1, 64, 64 to avoid recompilation when processing images of different resolutions.
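Why the fixed block sizes help can be illustrated without a GPU. Assuming BC, BH and BW are passed to the Triton kernels as tl.constexpr parameters (as block sizes in Triton kernels usually are), every distinct (BC, BH, BW) tuple triggers a separate JIT compilation. The sketch below reimplements next_power_of_2 locally and counts how many tuples the original heuristic produces over a few hypothetical feature-map shapes, compared with the fixed choice.

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (n >= 1); same result as triton.next_power_of_2.
    return 1 << (n - 1).bit_length()


def dynamic_blocks(C: int, H: int, W: int) -> tuple:
    # The original heuristic. Note min(next_power_of_2(C), 1) is always 1.
    return (min(next_power_of_2(C), 1),
            min(next_power_of_2(H), 64),
            min(next_power_of_2(W), 64))


def fixed_blocks(C: int, H: int, W: int) -> tuple:
    # The suggested replacement: one constant tile shape for every input.
    return (1, 32, 32)


def count_variants(block_fn, shapes) -> int:
    # Each distinct (BC, BH, BW) tuple corresponds to one compiled kernel variant
    # (under the tl.constexpr assumption stated above).
    return len({block_fn(C, H, W) for C, H, W in shapes})


# Hypothetical feature-map shapes (C, H, W) from images of different sizes.
shapes = [(96, 13, 17), (96, 25, 34), (96, 50, 68), (96, 100, 136)]
print(count_variants(dynamic_blocks, shapes))  # 3
print(count_variants(fixed_blocks, shapes))    # 1
```

With detection inputs of varying resolution, the dynamic heuristic keeps hitting new tuples early in training, and each one pays a compilation cost; a constant tile shape compiles once. The trade-off is that on small feature maps part of each tile is presumably masked out, which wastes some work per launch.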

Hprairie commented 3 days ago

VMamba works fine for me; however, I am confused by this table. Why is the training throughput for some of these models so much higher than the inference throughput? For example, DeiT is 2x faster in training than in inference. This makes no sense to me, and some clarification would be awesome. Thanks for the great work.

MzeroMiko commented 2 days ago

The training throughput looks higher because it was measured under mixed precision (i.e., torch.amp), whereas the inference throughput was measured purely in float32.
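Because the two columns were measured under different precisions, they are not directly comparable. The framework-agnostic timing helper below is a hypothetical sketch, not the harness used for the paper's table. For a PyTorch GPU model, one would pass torch.cuda.synchronize as sync and, to reproduce the mixed-precision setting, run the work inside step under torch.autocast(device_type="cuda", dtype=torch.float16); leaving autocast out gives the float32 setting. Timing training and inference under the same setting gives a like-for-like comparison.

```python
import time
from typing import Callable, Optional


def measure_throughput(step: Callable[[], None], batch_size: int,
                       warmup: int = 5, iters: int = 20,
                       sync: Optional[Callable[[], None]] = None) -> float:
    """Samples per second for `step`, excluding warm-up iterations."""
    for _ in range(warmup):  # absorb compilation, autotuning and cache effects
        step()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing
    start = time.perf_counter()
    for _ in range(iters):
        step()
    if sync is not None:
        sync()  # wait for queued (e.g. CUDA) work before stopping the clock
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


# Stand-in step for illustration: roughly 10 ms per batch of 8 samples.
tp = measure_throughput(lambda: time.sleep(0.01), batch_size=8, warmup=1, iters=5)
print(f"{tp:.0f} samples/s")
```

The explicit sync matters on GPUs because CUDA kernel launches are asynchronous; without it the clock can stop before the queued work has actually run.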

Hprairie commented 2 days ago

Ahh I see, thanks!