state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Slow Mamba 2 training speeds with higher d_state values #479

Open jacob-morrison opened 3 months ago

jacob-morrison commented 3 months ago

Hi! I'm training a small Mamba2 model (~60M non-embedding parameters to start) and running some benchmarks before committing to larger runs. Any ideas why throughput drops so much at higher d_state values? I'm importing the Mamba2 blocks directly from this repo, and I've compared Mamba2 against Mamba2Simple without noticing any difference.

With different d_state values, I get:

d_state=128: 65k tokens/device/second (58M non-embedding parameters total)
d_state=64: 100k tokens/device/second (56M non-embedding parameters total)
d_state=32: 150k tokens/device/second (53.7M non-embedding parameters total)
d_state=16: 195k tokens/device/second (52.9M non-embedding parameters total)
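One way to see how the cost grows is to convert the reported throughputs into time per token and fit a line against d_state. The sketch below uses only the four numbers above; the linear model is an assumption chosen for illustration, not a claim about the complexity of the underlying kernels.

```python
# Throughputs reported above (tokens/device/second, H100, DDP), keyed by d_state.
THROUGHPUT = {128: 65_000, 64: 100_000, 32: 150_000, 16: 195_000}


def us_per_token(tokens_per_s: float) -> float:
    """Convert a throughput into microseconds spent per token."""
    return 1e6 / tokens_per_s


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a + b * x; returns (a, b)."""
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    return my - b * mx, b


d_states = sorted(THROUGHPUT)  # [16, 32, 64, 128]
costs = [us_per_token(THROUGHPUT[d]) for d in d_states]
intercept, slope = fit_line(d_states, costs)

for d, c in zip(d_states, costs):
    print(f"d_state={d:4d}: {c:6.2f} us/token ({c / costs[0]:.2f}x the d_state=16 cost)")
print(f"linear fit: {intercept:.2f} us + {slope:.3f} us per unit of d_state")
```

With these four points, per-token time rises about 3x from d_state=16 to d_state=128 (195k / 65k), and a straight line in d_state fits the data closely.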

All experiments are running on H100s using DDP, and I'm using these other parameters:

n_layers: 16
vocab_size: 50280
embedding_size: 50304
d_conv: 4
expand: 2
ngroups: 1 (have also tried 8)
headdim: 64
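To check how much of the model actually depends on d_state, here is a rough per-layer parameter count using the settings listed above. It assumes the layer layout of mamba_ssm's Mamba2 module with default flags (no linear bias, conv bias on, gated RMSNorm, D per head); that layout is my reading of the source and may differ between versions. d_model is not stated in the issue, so the value used below is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Mamba2LayerConfig:
    d_model: int  # hypothetical: not given in the issue
    d_state: int = 128
    d_conv: int = 4
    expand: int = 2
    ngroups: int = 1
    headdim: int = 64


def mamba2_layer_params(cfg: Mamba2LayerConfig) -> int:
    """Approximate parameter count of one Mamba2 mixer layer.

    Assumed layout (verify against your installed mamba_ssm version):
    in_proj -> [z, x, B, C, dt], depthwise conv1d over [x, B, C],
    per-head dt_bias / A_log / D, gated RMSNorm, out_proj.
    Block-level norms, embeddings and any MLP are not counted.
    """
    d_inner = cfg.expand * cfg.d_model
    nheads = d_inner // cfg.headdim
    bc_dim = 2 * cfg.ngroups * cfg.d_state                  # B and C channels
    in_proj = cfg.d_model * (2 * d_inner + bc_dim + nheads)  # no bias
    conv_dim = d_inner + bc_dim
    conv1d = conv_dim * cfg.d_conv + conv_dim                 # depthwise weights + bias
    per_head = 3 * nheads                                     # dt_bias, A_log, D
    norm = d_inner                                            # RMSNorm weight
    out_proj = d_inner * cfg.d_model                          # no bias
    return in_proj + conv1d + per_head + norm + out_proj


if __name__ == "__main__":
    D_MODEL = 512  # hypothetical width, for illustration only
    base = mamba2_layer_params(Mamba2LayerConfig(d_model=D_MODEL, d_state=16))
    for d_state in (16, 32, 64, 128):
        n = mamba2_layer_params(Mamba2LayerConfig(d_model=D_MODEL, d_state=d_state))
        print(f"d_state={d_state:4d}: {n:,} params/layer (+{n - base:,} vs d_state=16)")
```

Under this layout each extra unit of d_state adds 2 * ngroups * (d_model + d_conv + 1) parameters per layer. That growth is small relative to the reported totals (52.9M to 58M), which suggests parameter count alone does not explain a roughly 3x throughput gap.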

For comparison, we get ~350k tokens/second/device with a 60M-parameter transformer model in the same codebase.

YaoMufeng commented 3 months ago

My guess is that a higher d_state means more model parameters. Glad to see you got Mamba2 running; I ran into many errors when installing it. Would you mind sharing your torch, triton, and causal_conv1d versions?
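A small sketch for collecting the versions being asked about, so they can be pasted into a reply. It reads package metadata through importlib.metadata instead of importing the packages, using the distribution names that appear in the wheel filenames later in this thread; the CUDA version is read from torch only if torch is importable.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names as they appear in the wheel filenames quoted in this thread.
PACKAGES = ("torch", "triton", "causal_conv1d", "mamba_ssm")


def installed_versions(packages=PACKAGES) -> dict:
    """Return {name: version or None} for Python and each package."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    try:
        import torch  # only needed for the CUDA version torch was built with
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        report["torch_cuda"] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```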

zh7117 commented 3 months ago


I failed to build mamba-ssm from source, but I successfully installed mamba-ssm and causal_conv1d from prebuilt wheels and ran Mamba2 without problems.

CUDA version: 11.8
torch version: 2.1.2, installed with:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
mamba wheel: https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
causal_conv1d wheel: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
(Select the wheels that match your Python version.)
triton version: 2.1.0
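Picking the right wheel means matching the CUDA, torch, C++ ABI and Python tags in the filename. The helper below rebuilds a filename from those pieces; the naming pattern is inferred only from the two URLs above, and the function name is hypothetical. Other releases may tag CUDA differently, so compare the result against the actual release page.

```python
def prebuilt_wheel_name(package: str, pkg_version: str, cuda: str,
                        torch_version: str, cxx11_abi: bool,
                        python: tuple, platform_tag: str = "linux_x86_64") -> str:
    """Build a wheel filename following the pattern of the URLs in this thread."""
    cu_tag = "cu" + cuda.replace(".", "")                          # "11.8" -> "cu118"
    torch_tag = "torch" + ".".join(torch_version.split(".")[:2])   # "2.1.2" -> "torch2.1"
    abi_tag = "cxx11abi" + ("TRUE" if cxx11_abi else "FALSE")
    py_tag = f"cp{python[0]}{python[1]}"                           # (3, 10) -> "cp310"
    return (f"{package}-{pkg_version}+{cu_tag}{torch_tag}{abi_tag}"
            f"-{py_tag}-{py_tag}-{platform_tag}.whl")


if __name__ == "__main__":
    # The setup described in this thread: CUDA 11.8, torch 2.1.2, Python 3.10.
    for pkg, ver in (("mamba_ssm", "2.2.2"), ("causal_conv1d", "1.4.0")):
        print(prebuilt_wheel_name(pkg, ver, "11.8", "2.1.2", False, (3, 10)))
```

For a different Python version, only the cp tag changes, e.g. (3, 11) gives cp311.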

Hope this is helpful to you.