EasonXiao-888 / GrootVL

The official implementation of GrootVL: Tree Topology is All You Need in State Space Model

🌲 GrootVL: Tree Topology is All You Need in State Space Model

Yicheng Xiao1,*, Lin Song2,3,📧,*, Shaoli Huang3, Jiangshan Wang1, Siyu Song4, Yixiao Ge2,3, Xiu Li1,📧 and Ying Shan2,3

* Equal contribution 📧 Corresponding author

1 Tsinghua University 2 ARC Lab, Tencent PCG 3 Tencent AI Lab 4 South China Normal University

📖 Abstract

State space models, which propagate features recursively, demonstrate representation capabilities comparable to Transformer models with superior efficiency. However, limited by the inherent geometric constraints of sequences, they still fall short in modeling long-range dependencies. To address this issue, we propose the GrootVL network, which first dynamically generates a tree topology based on spatial relationships and input features, and then propagates features along this tree, breaking the original sequence constraints to achieve stronger representation capabilities. Additionally, we introduce a linear-complexity dynamic programming algorithm that enhances long-range interactions without increasing computational cost. GrootVL is a versatile multimodal framework that can be applied to both visual and textual tasks. Extensive experiments demonstrate that our method significantly outperforms existing structured state space models on image classification, object detection, and segmentation. Moreover, when used to fine-tune large language models, our approach achieves consistent improvements across multiple textual tasks at a minor training cost.

⚓ Tree State Space Model with Tree Scanning Algorithm

We first revisit the selective state space model and design an input-aware topology scanning algorithm for state space modeling. Building on this algorithm, we develop a tree state space model (tree SSM) and propose a novel framework called GrootVL, which consists of two sub-networks: GrootV for visual tasks and GrootL for fine-tuning a pre-trained language model.
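To make the idea of an input-aware topology concrete, here is a small pure-Python sketch of one way to derive a tree from features: connect pixels in a 4-connected grid, weight each edge by how dissimilar the two neighbouring features are, and keep a minimum spanning tree so that similar neighbours stay linked. This is an illustration under stated assumptions, not the repository's `TreeScan` extension: the squared Euclidean dissimilarity, the use of Kruskal's algorithm, and the names `grid_edges`, `find`, and `minimum_spanning_tree` are all our own choices.

```python
def grid_edges(h, w):
    """Edges of a 4-connected h x w grid; pixel (r, c) has index r * w + c."""
    edges = []
    for r in range(h):
        for c in range(w):
            i = r * w + c
            if c + 1 < w:
                edges.append((i, i + 1))  # right neighbour
            if r + 1 < h:
                edges.append((i, i + w))  # bottom neighbour
    return edges


def find(parent, i):
    """Union-find root lookup with path halving."""
    while parent[i] != i:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i


def minimum_spanning_tree(feats):
    """feats: h x w nested list of feature vectors. Returns h*w - 1 tree edges."""
    h, w = len(feats), len(feats[0])
    flat = [vec for row in feats for vec in row]
    edges = grid_edges(h, w)
    # Assumed dissimilarity: squared Euclidean distance between neighbouring features.
    weights = [sum((a - b) ** 2 for a, b in zip(flat[u], flat[v])) for u, v in edges]
    parent = list(range(h * w))
    tree = []
    # Kruskal: visit edges from most to least similar, keep those joining two components.
    for k in sorted(range(len(edges)), key=weights.__getitem__):
        u, v = edges[k]
        ru, rv = find(parent, u), find(parent, v)
        if ru != rv:
            parent[ru] = rv
            tree.append((u, v))
    return tree


if __name__ == "__main__":
    # 2 x 3 image: the right column (value 5) differs from the rest (value 0).
    feats = [[[0.0], [0.0], [5.0]],
             [[0.0], [0.0], [5.0]]]
    print(minimum_spanning_tree(feats))  # [(0, 1), (0, 3), (1, 4), (2, 5), (1, 2)]
```

In the example, only one edge, `(1, 2)`, crosses the boundary between the two regions; every spanning tree needs at least one such edge, and the MST keeps exactly one, so propagation along the tree mostly stays inside regions of similar features.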

⛲ Efficient Implementation for Multi-Modality

We use a dynamic programming procedure to accelerate both inference and training, which reduces the complexity from $O(L^2)$ to linear $O(L)$, where $L$ is the sequence length.
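The following sketch illustrates why dynamic programming can make tree propagation linear. In a simplified scalar setting, every node $i$ aggregates all inputs $x_j$, each scaled by the product of edge weights along the tree path between $i$ and $j$: $y_i = \sum_j \big(\prod_{e \in \mathrm{path}(i,j)} w_e\big)\, x_j$. Evaluating every pair is $O(L^2)$; two passes over the tree (leaves to root, then root to leaves) give the same result in $O(L)$. This is our own illustration, not the code in `GrootV/third-party/TreeScan`: the scalar inputs, the fixed per-edge weights standing in for the input-dependent state transition terms, and the names `bfs_order`, `tree_scan_dp`, and `tree_scan_naive` are assumptions.

```python
from collections import deque


def bfs_order(parent):
    """Nodes in breadth-first order from the root (the node with parent -1)."""
    children = [[] for _ in parent]
    root = 0
    for v, p in enumerate(parent):
        if p < 0:
            root = v
        else:
            children[p].append(v)
    order, queue = [], deque([root])
    while queue:
        v = queue.popleft()
        order.append(v)
        queue.extend(children[v])
    return order


def tree_scan_dp(x, parent, w):
    """O(L): y[i] = sum_j x[j] * (product of edge weights on the path i..j).

    w[v] is the weight of the edge (v, parent[v]); w[root] is ignored.
    """
    order = bfs_order(parent)
    up = list(x)                       # up[v]: aggregate over the subtree rooted at v
    for v in reversed(order):          # pass 1: leaves to root
        p = parent[v]
        if p >= 0:
            up[p] += w[v] * up[v]
    y = list(up)                       # y[root] is already complete
    for v in order:                    # pass 2: root to leaves
        p = parent[v]
        if p >= 0:
            # y[p] minus v's own subtree contribution = everything outside v's subtree.
            y[v] = up[v] + w[v] * (y[p] - w[v] * up[v])
    return y


def tree_scan_naive(x, parent, w):
    """O(L^2) reference: explore the whole tree from every node separately."""
    n = len(x)
    adj = [[] for _ in range(n)]
    for v, p in enumerate(parent):
        if p >= 0:
            adj[v].append((p, w[v]))
            adj[p].append((v, w[v]))
    y = []
    for i in range(n):
        total, stack, seen = 0.0, [(i, 1.0)], {i}
        while stack:
            u, s = stack.pop()          # s = product of weights on the path i..u
            total += s * x[u]
            for nb, wt in adj[u]:
                if nb not in seen:
                    seen.add(nb)
                    stack.append((nb, s * wt))
        y.append(total)
    return y


if __name__ == "__main__":
    parent = [-1, 0, 0, 1, 1]          # a 5-node tree rooted at node 0
    w = [0.0, 0.5, 0.8, 0.9, 0.3]      # w[0] is unused (root)
    x = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(tree_scan_dp(x, parent, w))
    print(tree_scan_naive(x, parent, w))
```

Both calls should print the same values up to floating-point rounding; for node 0, for instance, $1 + 0.5\,(2 + 0.9 \cdot 4 + 0.3 \cdot 5) + 0.8 \cdot 3 = 6.95$. The second pass reuses the parent's finished result and subtracts the child's own subtree contribution, so each node is visited a constant number of times instead of once per other node.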


🛠️ Environment Setup

Vision Tasks

```bash
conda create -n grootv python=3.9
conda activate grootv

# Install PyTorch (CUDA 11.7 build)
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117

# Install other packages
pip install -r GrootV/grootv_requirements.txt

# Install Vision_Tree_Scanning
cd GrootV/third-party/TreeScan
pip install -v -e .
```

Language Tasks

```bash
conda create -n grootl python=3.9
conda activate grootl

# Install PyTorch
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2

# Install other packages
pip install -r GrootL/grootl_requirements.txt

# Install Language_Tree_Scanning
cd GrootL/third-party/TreeScanLan
pip install -v -e .
cd ../../..

# Install language model evaluation tools
cd GrootL/3rdparty/lm-evaluation-harness
pip install -v -e .
```

🍺 Model Zoo

Vision Tasks

ImageNet-1k Image Classification
| name | pretrain | resolution | acc@1 | #param | FLOPs | download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| GrootV-T | ImageNet-1K | 224x224 | 83.4 | 30M | 4.8G | [ckpt](https://drive.google.com/file/d/1OIiMBxk92WhPssRg0pv0U5y8ZBtHq2eQ/view?usp=drive_link) \| [cfg](GrootV/classification/config/grootv_t_1k_224.yaml) |
| GrootV-S | ImageNet-1K | 224x224 | 84.2 | 51M | 8.5G | [ckpt](https://drive.google.com/file/d/1G6DdEI3JDSltfbmZmcqz79rxM1t01K_8/view?usp=drive_link) \| [cfg](GrootV/classification/config/grootv_s_1k_224.yaml) |
| GrootV-B | ImageNet-1K | 224x224 | 84.8 | 91M | 15.1G | [ckpt](https://drive.google.com/file/d/1-8rwMVinj_fV9YMlzx6fGwdjxyLcjd3w/view?usp=drive_link) \| [cfg](GrootV/classification/config/grootv_b_1k_224.yaml) |
| GrootV-L | ImageNet-22K | 384x384 | RUNNING | - | - | - |
COCO Object Detection and Instance Segmentation
| backbone | method | schedule | box mAP | mask mAP | #param | FLOPs | download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| GrootV-T | Mask R-CNN | 1x | 47.0 | 42.7 | 49M | 265G | - |
| GrootV-T | Mask R-CNN | 3x | 49.0 | 43.8 | 49M | 265G | - |
| GrootV-S | Mask R-CNN | 1x | 48.6 | 43.6 | 70M | 341G | - |
| GrootV-S | Mask R-CNN | 3x | 50.1 | 44.6 | 70M | 341G | - |
ADE20K Semantic Segmentation
| backbone | method | resolution | mIoU (single-scale / multi-scale) | #param | FLOPs | download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| GrootV-T | UperNet | 512x512 | 48.5 / 49.4 | 60M | 941G | - |
| GrootV-S | UperNet | 512x512 | 50.7 / 51.7 | 82M | 1019G | - |

Language Tasks

Language Understanding
| Method | PIQA ↑ | ARC-E ↑ | SST ↑ | WinoGrande ↑ | LAMBADA ppl ↓ | RACE ↑ | OpenBookQA ↑ | Average Acc ↑ | download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Mamba | 64.5 | 48.0 | 65.6 | 51.8 | 16.1 | 27.4 | 16.8 | 45.7 | [model](https://huggingface.co/state-spaces/mamba-130m-hf) |
| +LoRA | 64.7 | 48.3 | 65.1 | 52.2 | 17.7 | 28.6 | 17.8 | 46.1 | - |
| +GrootL | 65.0 | 49.8 | 69.5 | 51.1 | 15.9 | 28.9 | 19.2 | 47.2 | [model](https://drive.google.com/file/d/1oby7sHYUxg4TIqFjIXSk8GB98fa9kSBm/view?usp=drive_link) |
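The Average Acc column is consistent with the plain mean of the six accuracy columns (PIQA, ARC-E, SST, WinoGrande, RACE, OpenBookQA), with LAMBADA perplexity left out since lower is better there. This is our reading of the numbers rather than a definition stated by the authors; the quick check below recomputes the means from the table values.

```python
# Accuracy columns as listed in the table, in order: PIQA, ARC-E, SST,
# WinoGrande, RACE, OpenBookQA. LAMBADA perplexity is deliberately left out.
rows = {
    "Mamba": ([64.5, 48.0, 65.6, 51.8, 27.4, 16.8], 45.7),
    "+LoRA": ([64.7, 48.3, 65.1, 52.2, 28.6, 17.8], 46.1),
    "+GrootL": ([65.0, 49.8, 69.5, 51.1, 28.9, 19.2], 47.2),
}

for name, (accs, reported) in rows.items():
    mean = sum(accs) / len(accs)
    print(f"{name:8s} mean={mean:.2f} reported={reported}")
```

The means come out to about 45.68, 46.12, and 47.25, each within 0.05 of the reported average.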

🚀 Train & Evaluate

ImageNet-1k Image Classification
Run `bash GrootV/scripts/bash_cls_train.sh` after replacing the paths in the script with your own.
Language Understanding
Run `cd GrootL && bash eval.sh` after replacing the paths in the script with your own.

⭐️ BibTeX

If you find this work useful for your research, please cite:

```bibtex
@article{xiao2024grootvl,
  title={GrootVL: Tree Topology is All You Need in State Space Model},
  author={Xiao, Yicheng and Song, Lin and Huang, Shaoli and Wang, Jiangshan and Song, Siyu and Ge, Yixiao and Li, Xiu and Shan, Ying},
  journal={arXiv preprint arXiv:2406.02395},
  year={2024}
}
```

❤️ Acknowledgement

Code in this repository is built upon several public repositories. We thank the authors of InternImage and VMamba for their wonderful work!

☑️ LICENSE

Our code is released under the MIT license.