Open wangyao0512 opened 2 weeks ago
VMamba-T is indeed slower to train than Swin-T, but 5 times slower is really beyond my imagination; in our experience, it is less than 2 times slower in training.
The appendix of arXiv paper v3 says:
Also, I immediately reproduced my experiments on a node with 8 × A800 GPUs and an Intel Xeon Platinum 8358 CPU; the results show that VMamba-T with Mask R-CNN takes about 11 hours to train, while Swin-T only needs about 7 hours, which is consistent with the arXiv paper.
By the way, if you are using VMamba-T and want to start from the checkpoint pretrained on classification, you need to load the pretrained model explicitly: the config parameter `model.backbone.pretrained` is left blank in VMamba, whereas in Swin it is filled with a URL pointing to the checkpoint storage.
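To make the point concrete, here is a hypothetical mmdetection-style config fragment. The checkpoint path and the backbone `type` name are illustrative placeholders, not the repo's actual values; only the `model.backbone.pretrained` field mirrors the parameter mentioned above.

```python
# Illustrative config fragment: fill in model.backbone.pretrained yourself,
# since the VMamba detection config leaves it blank by default.
model = dict(
    backbone=dict(
        # Point this at the classification-pretrained checkpoint you downloaded
        # (path is a placeholder):
        pretrained='path/to/vssm_tiny_classification.pth',
    ),
)
```

If this field stays blank, the backbone trains from random initialization, which typically explains a large drop in downstream accuracy.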
As I've emphasized, you should be careful about checkpoint loading; it may be the key to good performance on downstream tasks.
Finally, thank you for your attention; feel free to raise an issue here, and I'll try my best to answer when I am available.
Thank you very much for your reply. May I know your detailed configuration?
Assuming the code runs and trains successfully, what could cause the slow training speed, apart from problems with the environment configuration?
My environment is: CUDA 11.8 + cuDNN 8.7.0, Python 3.10.0, causal-conv1d 1.2.1, mamba-ssm 2.0.3, torch 2.1.0+cu118, triton 2.1.0
We do not use `mamba_ssm` or `causal-conv1d` in our code; instead, we modified the scan into `kernels/selective_scan`, which is supposed to be installed.
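For reference, installing that kernel package usually amounts to building it from its directory; a sketch assuming the directory layout implied by the name above (exact path and flags may differ in your checkout):

```shell
# Build and install the custom selective-scan CUDA kernel
# (path is an assumption based on the directory name mentioned above).
cd kernels/selective_scan
pip install .
```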
Another part that influences the speed is `cross_scan` and `cross_merge`. If you are using an old version of this file, update the code. You can also check the code in `csm_triton.py#CrossScanTriton` and `csm_triton.py#CrossMergeTriton` (in the latest version, those two classes have been replaced by the more general `cross_scan_fn` and `cross_merge_fn`). If you find the line `BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64)` in them, modify it to `BC, BH, BW = 1, 32, 32` or `BC, BH, BW = 1, 64, 64` to avoid re-compilation when dealing with images of different resolutions.
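To see why the original line can trigger re-compilation: `triton.next_power_of_2` rounds the runtime `H` and `W` up to powers of two, so different input resolutions can produce different `(BC, BH, BW)` tuples, and Triton JIT-compiles a fresh kernel variant for each distinct block-size combination. A minimal pure-Python sketch of that rounding behavior (reimplementing `next_power_of_2` locally rather than importing Triton):

```python
def next_power_of_2(n: int) -> int:
    # Round n up to the nearest power of two (mirrors triton.next_power_of_2).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def dynamic_blocks(C, H, W):
    # Dynamic block sizes as in the original line: each resolution may yield
    # a new (BC, BH, BW) tuple, and each distinct tuple forces another
    # Triton JIT compilation.
    return (min(next_power_of_2(C), 1),
            min(next_power_of_2(H), 64),
            min(next_power_of_2(W), 64))

print(dynamic_blocks(96, 200, 304))  # (1, 64, 64) -- H, W capped at 64
print(dynamic_blocks(96, 20, 24))    # (1, 32, 32) -- new tuple => recompile
```

Hard-coding `BC, BH, BW = 1, 64, 64` gives a single tuple for every resolution, so the kernel is compiled once and reused.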
VMamba works fine for me, however, I am confused about this table. Why is the training throughput for some of these so much higher than the inference throughput? For example, DeiT is 2x faster training than running inference. This makes no sense and some clarification would be awesome. Thanks for the great work.
The training throughput appears higher because it was measured with mixed precision (i.e., torch.amp), while the inference throughput was measured purely in float32.
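So the two columns were measured under different numeric settings rather than indicating a bug. As background, throughput here just means images processed per second; a generic timing harness in this style (an illustrative sketch, not the repo's actual benchmark script) would be:

```python
import time

def throughput(run_batch, batch_size, n_iters=30, n_warmup=5):
    # Generic throughput harness: warm up first, then time n_iters batches
    # and report images per second.
    for _ in range(n_warmup):
        run_batch()
    start = time.perf_counter()
    for _ in range(n_iters):
        run_batch()
    elapsed = time.perf_counter() - start
    return n_iters * batch_size / elapsed
```

In the setup described above, the training measurement would wrap `run_batch` in an autocast (mixed-precision) context while the inference measurement runs in plain float32, which on tensor-core GPUs can by itself make the training number substantially larger.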
Ahh I see, thanks!
I have been running the Swin Transformer and VMamba models on the same A800 GPU, using the same batch size and the COCO2017 detection dataset. However, I've observed that VMamba is at least 5 times slower than the Swin Transformer at the same training stage. VMamba config: mask_rcnn_coco_fpn_base.py; Swin Transformer config: mask_rcnn_swin-t-p4-w7-1x-coco_base.py
Additionally, when testing on my own dataset, VMamba's results were inferior to those of the vanilla Transformer. What could be the possible reasons? Is the VMamba model really valid?