sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License
1.24k stars, 66 forks

Add implementations of Mamba 2 into FLA #39

Closed. DanFosing closed this 1 month ago

DanFosing commented 1 month ago

Modified version of https://github.com/huggingface/transformers/tree/add_codestral_mamba2/src/transformers/models/mamba2 (https://github.com/huggingface/transformers/pull/32080) adapted to work with FLA. I haven't tested it yet, but I'm pretty sure it will work. It could probably be made a bit faster by implementing a gated RMSNorm that uses SiLU.
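
For reference, a gated RMSNorm with a SiLU gate can be written unfused in plain PyTorch, which is handy for checking a fused kernel against. This is a minimal sketch, not the FLA or mamba_ssm implementation; the class name `GatedRMSNormRef` is hypothetical, and the `norm_before_gate` switch is assumed to mean "normalize, then gate" versus "gate, then normalize", mirroring the config field of the same name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedRMSNormRef(nn.Module):
    """Hypothetical unfused reference: RMSNorm combined with a SiLU (swish) gate."""

    def __init__(self, hidden_size: int, eps: float = 1e-5, norm_before_gate: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable per-channel scale
        self.eps = eps
        self.norm_before_gate = norm_before_gate

    def _rms(self, x: torch.Tensor) -> torch.Tensor:
        # normalize by the root mean square over the last dim, computed in float32
        x32 = x.float()
        return x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        g = F.silu(gate.float())
        if self.norm_before_gate:
            y = self._rms(x) * self.weight * g           # norm(x) * w * silu(gate)
        else:
            y = self._rms(x.float() * g) * self.weight   # norm(x * silu(gate)) * w
        return y.to(x.dtype)
```

A fused Triton kernel computes the same result in a single pass instead of materializing the normalized and gated intermediates, which is where the speedup would come from.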

yzhangcs commented 1 month ago

@DanFosing Great job! I'll run some tests soon. Thank you for the contributions.

yzhangcs commented 1 month ago

Could you add an __init__.py file like https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/models/mamba/__init__.py to make the model recognizable by fla, and register it in https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/models/__init__.py?

DanFosing commented 1 month ago

I think it should work with FLA now.

DanFosing commented 1 month ago

Unfortunately, because it's still an open pull request in the transformers package, it may not fully work in some specific cases; that's why it needs some testing.

yzhangcs commented 1 month ago

@DanFosing Hi, it looks like there are still some errors:

  File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 607, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 333, in cuda_kernels_forward
    rmsnorm_weight=self.norm.weight,
                   ^^^^^^^^^^^^^^^^
  File "anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GatedRMSNorm' object has no attribute 'weight'

yzhangcs commented 1 month ago

@DanFosing You can check your implementation by running

python -m benchmarks.benchmark_training_throughput --name mamba2

yzhangcs commented 1 month ago

Also, it would be better to format your code with pre-commit so that it follows the PEP 8 guidelines.

yzhangcs commented 1 month ago

You can refer to https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/layers/gla.py#L219-L228 for our implementation of the gated norm. I don't think there is any need to wrap it in another GatedNorm.

yzhangcs commented 1 month ago

@DanFosing Appreciate your hard work and quick response.

DanFosing commented 1 month ago

Are you sure it works? When I tried it on Kaggle I got assertion errors, but Kaggle often has some weird problems, so it may be an issue with Kaggle (I can't test it on my PC right now).

yzhangcs commented 1 month ago

Could you paste more detailed information?

yzhangcs commented 1 month ago

@DanFosing Here is the output

$ python -m benchmarks.benchmark_training_throughput --name mamba2
Initializing mamba2 model from the config:
Mamba2Config {
  "bos_token_id": 1,
  "chunk_size": 256,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "expand": 2,
  "fuse_cross_entropy": true,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.1,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mamba2",
  "n_groups": 8,
  "norm_before_gate": true,
  "num_heads": 64,
  "num_hidden_layers": 48,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "rms_norm": true,
  "state_size": 128,
  "tie_word_embeddings": false,
  "time_step_floor": 0.0001,
  "time_step_limit": [
    0.0,
    Infinity
  ],
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 128,
  "transformers_version": "4.43.3",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 32000
}

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Mamba2Block(
        (norm): RMSNorm(2048, eps=1e-05)
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (in_proj): Linear(in_features=2048, out_features=10304, bias=False)
          (norm): FusedRMSNormSwishGate(4096, eps=1e-05)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
      )
    )
    (norm_f): RMSNorm(2048, eps=1e-05)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1548430336 (1.44GiB)
Allocated memory after initialization: 2.89GiB
Max memory allocated: 42.01GiB: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:11<00:00,  4.44s/it]
Thoughput:   32048.92 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:16<00:00,  1.96it/s]
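
The reported parameter total can be reproduced by hand from the config above. This sketch assumes the per-layer layout shown in the module dump (block RMSNorm, in_proj, depthwise conv1d, gated RMSNorm, out_proj) plus three per-head vectors (`dt_bias`, `A_log`, `D`) in each mixer, as in the Hugging Face Mamba-2 port; the function name is hypothetical.

```python
def mamba2_param_count(hidden_size=2048, vocab_size=32000, num_hidden_layers=48,
                       expand=2, num_heads=64, head_dim=64, n_groups=8,
                       state_size=128, conv_kernel=4, use_conv_bias=True,
                       use_bias=False, tie_word_embeddings=False):
    intermediate = expand * hidden_size                    # 4096
    assert intermediate == num_heads * head_dim
    conv_dim = intermediate + 2 * n_groups * state_size    # x, B, C go through conv1d: 6144
    proj_out = intermediate + conv_dim + num_heads         # z, xBC, dt: 10304
    per_layer = (
        hidden_size                                                      # block RMSNorm weight
        + hidden_size * proj_out + (proj_out if use_bias else 0)         # in_proj
        + conv_dim * conv_kernel + (conv_dim if use_conv_bias else 0)    # depthwise conv1d
        + 3 * num_heads                                                  # dt_bias, A_log, D (assumed)
        + intermediate                                                   # gated RMSNorm weight
        + intermediate * hidden_size + (hidden_size if use_bias else 0)  # out_proj
    )
    embeddings = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    norm_f = hidden_size
    return num_hidden_layers * per_layer + embeddings + lm_head + norm_f


total = mamba2_param_count()
print(total)                   # 1548430336
print(f"{total / 2**30:.2f}")  # 1.44
```

The count matches the log exactly, which suggests the "1.44GiB" printed next to it is the parameter count divided by 2^30 rather than a memory size.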

DanFosing commented 1 month ago

It turns out it was just a problem with the T4 GPU not supporting some features used by the causal_conv1d package; the same issue occurs with Mamba-1 and Mamba-2 from mamba-ssm. I don't remember exactly what it was, but there was a workaround that fixed it by turning off some low-memory mode in PyTorch or something like that. By the way, torch.compile support was added to Mamba-1, but it requires some changes to MambaCache and the Mamba modeling code; I will try to implement it when I have some time.
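
When the causal_conv1d kernels are unavailable, the same depthwise causal convolution can be computed with a plain `nn.Conv1d` configured like the one in the module dump above (`groups` equal to the channel count, `padding=kernel_size - 1`), keeping only the first `seq_len` outputs. This is a minimal, slower fallback sketch with a hypothetical helper name; it is not the workaround mentioned in the comment above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_depthwise_conv1d(x: torch.Tensor, conv: nn.Conv1d) -> torch.Tensor:
    """x: (batch, channels, seq_len). `conv` is assumed depthwise (groups == channels)
    with padding == kernel_size - 1, as in the printed Mamba2Mixer.conv1d."""
    seq_len = x.shape[-1]
    # nn.Conv1d pads both ends, producing seq_len + kernel_size - 1 outputs.
    # Output t only sees inputs t - (kernel_size - 1) .. t, so truncating keeps it causal.
    y = conv(x)[..., :seq_len]
    return F.silu(y)  # SiLU activation after the convolution (hidden_act = "silu")


# usage with the benchmark config's shapes (6144 channels, kernel size 4)
conv = nn.Conv1d(6144, 6144, kernel_size=4, groups=6144, padding=3)
out = causal_depthwise_conv1d(torch.randn(1, 6144, 16), conv)
print(out.shape)  # torch.Size([1, 6144, 16])
```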

learning-chip commented 1 month ago

One question on this PR:

According to the FLA paper, both Mamba-1 and Mamba-2 can be written in the gated linear attention formulation:

[Image: fla_table, a table from the FLA paper]

So, the SSM part of Mamba-2 (excluding conv1d, normalization, ...) should permit the same "Chunkwise Parallel Form" as advocated by the FLA paper, no?

[Image: chunkwise parallel form]
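
To make the question concrete: for a single head with a scalar per-step decay (the Mamba-2 case), the recurrent form and a chunkwise parallel form produce the same output. The naive PyTorch sketch below is illustrative only, not FLA's Triton kernel or mamba_chunk_scan_combined; it assumes the usual SSD mapping q = C, k = B, v = Δ·x and log a_t = Δ_t·A, and the function names are hypothetical.

```python
import torch


def recurrent_ssd(q, k, v, log_a):
    # q, k: (T, d_k); v: (T, d_v); log_a: (T,) log of the scalar decay a_t in (0, 1]
    S = q.new_zeros(q.shape[1], v.shape[1])
    out = []
    for t in range(q.shape[0]):
        S = torch.exp(log_a[t]) * S + torch.outer(k[t], v[t])  # S_t = a_t S_{t-1} + k_t v_t^T
        out.append(q[t] @ S)                                   # o_t = S_t^T q_t
    return torch.stack(out)


def chunkwise_ssd(q, k, v, log_a, chunk_size=16):
    T = q.shape[0]
    S = q.new_zeros(q.shape[1], v.shape[1])  # state carried across chunks
    out = []
    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        qc, kc, vc = q[start:end], k[start:end], v[start:end]
        b = torch.cumsum(log_a[start:end], dim=0)  # b_i = sum_{j<=i} log a_j within the chunk
        # inter-chunk: previous state decayed up to position i
        inter = (qc * torch.exp(b)[:, None]) @ S
        # intra-chunk: causal attention weighted by exp(b_i - b_j) for j <= i
        L = end - start
        mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
        decay = torch.exp(b[:, None] - b[None, :]).masked_fill(~mask, 0.0)
        intra = ((qc @ kc.T) * decay) @ vc
        out.append(inter + intra)
        # carry the state to the end of the chunk
        S = torch.exp(b[-1]) * S + (kc * torch.exp(b[-1] - b)[:, None]).T @ vc
    return torch.cat(out)


torch.manual_seed(0)
T, d_k, d_v = 37, 8, 5
q = torch.randn(T, d_k, dtype=torch.float64)
k = torch.randn(T, d_k, dtype=torch.float64)
v = torch.randn(T, d_v, dtype=torch.float64)
log_a = -torch.rand(T, dtype=torch.float64)  # decays in (e^-1, 1]
print(torch.allclose(recurrent_ssd(q, k, v, log_a), chunkwise_ssd(q, k, v, log_a, chunk_size=8)))  # True
```

Within a chunk the work is dense matrix multiplication, and only the state hand-off between chunks is sequential, which is what makes the chunkwise form attractive for GPU kernels.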

Then the modeling_mamba2.py implementation here shouldn't need to import mamba_chunk_scan_combined from the original mamba_ssm repository, right? It could use a custom kernel similar to ops/gla/chunk_fuse.py in this repo.

Otherwise, modeling_mamba2.py in the current PR only uses the custom RMSNorm/FusedRMSNormSwishGate kernels, while the SSM part is no different from the original mamba_ssm repository. The performance will then remain largely the same as the original repo, and you cannot tell whether the FLA formulation is faster...

learning-chip commented 1 month ago

The Mamba-2 blog did mention that FLA's chunkwise parallel algorithm "turns out to be essentially equivalent to the SSD algorithm specialized to a restricted case"

[Image: special_case, from the Mamba-2 blog]

Will the FLA formulation be identical to mamba_chunk_scan_combined in the original Mamba-2 code, or is there still some room to improve on the original version?

DanFosing commented 1 month ago

Indeed, I think it can be made faster if a kernel similar to chunk_fuse.py is used (it might even be possible to just modify the GLA one, since GLA and Mamba-2 are extremely similar to each other). Unfortunately, I'm not familiar with Triton, so I can't really do it. I think it would be best if you made an issue or something, unless some of FLA devs replies there. The current implementation is a lot like the Mamba-1 implementation: both mostly follow the original mamba-ssm ones, with only minor speedups from FusedCrossEntropyLoss and the custom RMSNorm, and both are compatible with Hugging Face transformers.

DanFosing commented 1 month ago

If you take a look at this paper, https://arxiv.org/pdf/2406.06484, you can also see that in terms of recurrence and memory read-out, Mamba-2 is very similar to RetNet and linear attention: [image]
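
Written side by side, the recurrences differ only in how the previous state is decayed; the read-out is the same for all of them. This summary uses a d_k × d_v state S_t with the k v^⊤ convention, and for Mamba-2 it assumes the SSD mapping q_t = C_t, k_t = B_t, v_t = Δ_t x_t (the D skip term is left out).

```latex
\begin{aligned}
\text{Linear attention:}\quad & S_t = S_{t-1} + k_t v_t^\top \\
\text{RetNet:}\quad & S_t = \gamma\, S_{t-1} + k_t v_t^\top,
  && \gamma \in (0,1) \text{ fixed per head} \\
\text{Mamba-2:}\quad & S_t = \alpha_t\, S_{t-1} + k_t v_t^\top,
  && \alpha_t = \exp(\Delta_t A) \text{ a data-dependent scalar} \\
\text{GLA:}\quad & S_t = \operatorname{diag}(\alpha_t)\, S_{t-1} + k_t v_t^\top,
  && \alpha_t \in (0,1)^{d_k} \\
\text{Read-out (all):}\quad & o_t = S_t^\top q_t
\end{aligned}
```

Setting every entry of the GLA decay vector to the same scalar recovers the Mamba-2 recurrence, and making that scalar constant recovers RetNet.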

learning-chip commented 1 month ago

I think it would be best if you made an issue or something, unless some of FLA devs replies there.

Will look into it!

Also, @yzhangcs, any suggestions? For example, which existing Triton kernel is the best starting point for reimplementing the SSM part of Mamba-2? Judging from the table above, the RetNet kernels should be the closest?

The Mamba-2 blog also mentions the close connection between RetNet and Mamba-2 (under its SSD framework):

Prior examples include the original linear attention as well as the recent Retentive Network (RetNet) model [18] . These can be viewed as direct special cases of SSD.

yzhangcs commented 1 month ago

@learning-chip Hi, you may refer to simple GLA by @sustcsonglin, which adds data-dependent decay on top of RetNet.
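
As a hedged illustration of how the Mamba-2 SSM could be fed to a RetNet/simple-GLA-style computation with one data-dependent scalar decay per head, the sketch maps Mamba-2 tensors to (q, k, v, g) and checks the result against the direct SSD recurrence. All function names and the tensor layout are assumptions for illustration: a single sequence with no batch, heads assigned to B/C groups contiguously, `dt` already passed through softplus with its bias, and the D skip connection left out. It does not call FLA's actual simple GLA API.

```python
import torch


def mamba2_to_gla_inputs(x, dt, A, B, C):
    """Hypothetical helper. x: (T, H, P), dt: (T, H) > 0, A: (H,) < 0, B, C: (T, G, N)."""
    rep = x.shape[1] // B.shape[1]         # heads per B/C group (assumed contiguous)
    q = C.repeat_interleave(rep, dim=1)    # C acts as the query: (T, H, N)
    k = B.repeat_interleave(rep, dim=1)    # B acts as the key:   (T, H, N)
    v = x * dt[..., None]                  # input scaled by the step size: (T, H, P)
    g = dt * A                             # log of the per-head scalar decay exp(dt * A)
    return q, k, v, g


def scalar_gated_attention(q, k, v, g):
    # per head: S_t = exp(g_t) S_{t-1} + k_t v_t^T, o_t = S_t^T q_t
    T, H, N = q.shape
    S = q.new_zeros(H, N, v.shape[-1])
    out = []
    for t in range(T):
        S = torch.exp(g[t])[:, None, None] * S + k[t][:, :, None] * v[t][:, None, :]
        out.append(torch.einsum("hn,hnp->hp", q[t], S))
    return torch.stack(out)  # (T, H, P)


def mamba2_ssm_reference(x, dt, A, B, C):
    # direct SSD recurrence: h_t = exp(dt_t A) h_{t-1} + dt_t B_t x_t^T, y_t = C_t^T h_t
    T, H, P = x.shape
    rep = H // B.shape[1]
    h = x.new_zeros(H, B.shape[-1], P)
    ys = []
    for t in range(T):
        Bt = B[t].repeat_interleave(rep, dim=0)  # (H, N)
        Ct = C[t].repeat_interleave(rep, dim=0)  # (H, N)
        h = torch.exp(dt[t] * A)[:, None, None] * h \
            + dt[t][:, None, None] * Bt[:, :, None] * x[t][:, None, :]
        ys.append((Ct[:, :, None] * h).sum(1))
    return torch.stack(ys)


torch.manual_seed(0)
T, H, G, P, N = 12, 4, 2, 3, 5
x = torch.randn(T, H, P, dtype=torch.float64)
dt = torch.rand(T, H, dtype=torch.float64) + 0.1
A = -torch.rand(H, dtype=torch.float64) - 0.1
B = torch.randn(T, G, N, dtype=torch.float64)
C = torch.randn(T, G, N, dtype=torch.float64)
o = scalar_gated_attention(*mamba2_to_gla_inputs(x, dt, A, B, C))
print(torch.allclose(o, mamba2_ssm_reference(x, dt, A, B, C)))  # True
```

Under these assumptions, the only Mamba-2-specific work outside the scalar-decay kernel is building q, k, v and g, plus adding the D·x skip term afterwards.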