LeapLabTHU / MLLA

Official repository of MLLA

Demystify Mamba in Vision: A Linear Attention Perspective

This repo contains the official PyTorch code and pre-trained models for Mamba-Like Linear Attention (MLLA).

Abstract

Mamba is an effective state space model with linear computational complexity. It has recently shown impressive efficiency in dealing with high-resolution inputs across various vision tasks. In this paper, we reveal that the powerful Mamba model shares surprising similarities with the linear attention Transformer, which typically underperforms the conventional Transformer in practice. By exploring the similarities and disparities between the effective Mamba and the subpar linear attention Transformer, we provide comprehensive analyses to demystify the key factors behind Mamba's success. Specifically, we reformulate the selective state space model and linear attention within a unified formulation, rephrasing Mamba as a variant of the linear attention Transformer with six major distinctions: input gate, forget gate, shortcut, no attention normalization, single-head, and modified block design. For each design, we meticulously analyze its pros and cons, and empirically evaluate its impact on model performance in vision tasks. Interestingly, the results highlight the forget gate and block design as the core contributors to Mamba's success, while the other four designs are less crucial. Based on these findings, we propose a Mamba-Like Linear Attention (MLLA) model by incorporating the merits of these two key designs into linear attention. The resulting model outperforms various vision Mamba models in both image classification and high-resolution dense prediction tasks, while enjoying parallelizable computation and fast inference speed.
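The linear complexity of linear attention comes from reordering the attention product. The minimal NumPy sketch below shows generic (non-causal) linear attention and is not the MLLA code. It assumes the feature map φ(x) = ELU(x) + 1, which is chosen only for illustration. Computing φ(K)ᵀV first costs O(N·d·d_v) instead of the O(N²·d) needed to form the N×N similarity matrix, and it gives the same output. The function names are hypothetical.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = ELU(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise,
    # so all features are strictly positive (illustrative kernel choice).
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(q, k, v):
    """Linear in N: aggregate phi(K)^T V once, then query it with phi(Q)."""
    qf, kf = elu_plus_one(q), elu_plus_one(k)
    kv = kf.T @ v                          # (d, d_v) summary of all keys/values
    z = kf.sum(axis=0)                     # (d,) summary used for normalization
    return (qf @ kv) / (qf @ z)[:, None]   # (N, d_v)

def linear_attention_quadratic(q, k, v):
    """Same computation in the quadratic order, kept only for comparison."""
    qf, kf = elu_plus_one(q), elu_plus_one(k)
    sim = qf @ kf.T                        # (N, N) similarity matrix
    return (sim @ v) / sim.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, d = 196, 32                             # e.g. tokens of a 14x14 feature map
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
out = linear_attention(q, k, v)
```

Both orderings give the same result because matrix multiplication is associative. In a vision backbone, N is the number of image tokens, so the savings grow with input resolution.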

Connecting Mamba and Linear Attention Transformer

This paper reveals Mamba's close relationship to the linear attention Transformer. Both can be formulated within a unified framework, in which Mamba differs from the conventional linear attention paradigm in six designs: input gate, forget gate, shortcut, no attention normalization, single-head, and modified block design.
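The following simplified NumPy sketch shows how five of the six distinctions can appear as changes to the causal linear attention recurrence: input gate, forget gate, shortcut, no attention normalization, and single head. Block design is outside the scope of a recurrence. This is not the MLLA code or Mamba's actual parameterization. The sketch makes several assumptions: a scalar forget gate per step, a per-channel input gate, a per-channel shortcut weight, and q and k that are already non-negative features. All function and argument names are hypothetical.

```python
import numpy as np

def causal_linear_attention(q, k, v):
    """Recurrent linear attention: S_i = S_{i-1} + k_i^T v_i,
    z_i = z_{i-1} + k_i, y_i = q_i S_i / (q_i . z_i).
    q, k: (N, d) non-negative features; v: (N, C)."""
    n, d = q.shape
    S = np.zeros((d, v.shape[1]))   # running sum of outer products k_i^T v_i
    z = np.zeros(d)                 # running sum of keys for the normalizer
    out = np.empty_like(v)
    for i in range(n):
        S += np.outer(k[i], v[i])
        z += k[i]
        out[i] = (q[i] @ S) / (q[i] @ z)
    return out

def mamba_style_linear_attention(q, k, v, delta, a, D):
    """Hypothetical gated variant of the recurrence above.
    delta: (N, C) input gate scaling each incoming value
    a:     (N,)   forget gate in [0, 1] decaying the previous state
    D:     (C,)   shortcut weight added to the output
    Single head (one d x C state) and no normalizer."""
    n, d = q.shape
    S = np.zeros((d, v.shape[1]))
    out = np.empty_like(v)
    for i in range(n):
        S = a[i] * S + np.outer(k[i], delta[i] * v[i])  # forget + input gates
        out[i] = q[i] @ S + D * v[i]                     # shortcut, no normalization
    return out

rng = np.random.default_rng(0)
N, d, C = 16, 4, 8
q, k = rng.random((N, d)), rng.random((N, d))
v = rng.standard_normal((N, C))
# With a = 1, delta = 1 and D = 0, the gated form is unnormalized linear attention.
ungated = mamba_style_linear_attention(q, k, v, np.ones((N, C)), np.ones(N), np.zeros(C))
normalizer = np.einsum("nd,nd->n", q, np.cumsum(k, axis=0))
```

Setting the gates to 1 and the shortcut to 0 recovers plain linear attention up to the normalizer. Setting the forget gate to 0 discards all history, so each output depends only on the current token.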


Dependencies

The ImageNet dataset should be prepared as follows:

imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img2.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img3.jpeg
    │   └── ...
    ├── class2
    │   ├── img4.jpeg
    │   └── ...
    └── ...
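The tree above uses the class-per-folder layout that torchvision's ImageFolder expects. The standard-library sketch below checks a local copy before training is launched. The function names and the set of accepted image extensions are hypothetical and should be adjusted to your copy of the dataset.

```python
from pathlib import Path

# Assumed image extensions; extend if your copy of ImageNet uses others.
IMG_EXTS = {".jpeg", ".jpg", ".png"}

def count_images(split_dir: Path) -> dict:
    """Return {class_folder_name: number_of_image_files} for one split."""
    counts = {}
    for cls_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[cls_dir.name] = sum(
            1 for f in cls_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMG_EXTS
        )
    return counts

def check_imagenet_layout(data_path: str) -> tuple:
    """Validate <data_path>/{train,val}/<class>/<image> and return
    (num_classes, num_train_images, num_val_images)."""
    root = Path(data_path)
    train = count_images(root / "train")
    val = count_images(root / "val")
    if set(train) != set(val):
        raise ValueError("train/ and val/ must contain the same class folders")
    empty = [f"{split}/{cls}"
             for split, counts in (("train", train), ("val", val))
             for cls, n in counts.items() if n == 0]
    if empty:
        raise ValueError(f"class folders with no images: {empty}")
    return len(train), sum(train.values()), sum(val.values())
```

For a full ImageNet-1K copy, the first returned value should be 1000.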

Pretrained Models

Model | Resolution | #Params | FLOPs | Acc@1 (%) | Config | Pretrained weights
--- | --- | --- | --- | --- | --- | ---
MLLA-T | 224 | 25M | 4.2G | 83.5 | config | TsinghuaCloud
MLLA-S | 224 | 43M | 7.3G | 84.4 | config | TsinghuaCloud
MLLA-B | 224 | 96M | 16.2G | 85.3 | config | TsinghuaCloud

Model Training and Inference

Evaluate a pretrained model on ImageNet with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>

Train a model on ImageNet from scratch with 8 GPUs and automatic mixed precision (--amp):

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --amp

Acknowledgements

This code is built on top of the Swin Transformer codebase.

Citation

If you find this repo helpful, please consider citing us.

@article{han2024demystify,
  title={Demystify Mamba in Vision: A Linear Attention Perspective},
  author={Han, Dongchen and Wang, Ziyi and Xia, Zhuofan and Han, Yizeng and Pu, Yifan and Ge, Chunjiang and Song, Jun and Song, Shiji and Zheng, Bo and Huang, Gao},
  journal={arXiv preprint arXiv:2405.16605},
  year={2024}
}

Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: hdc23@mails.tsinghua.edu.cn