Closed · DanFosing closed this 1 month ago
@DanFosing Great job! I'll run some tests soon. Thank you for the contributions.
Could you add an __init__.py file like https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/models/mamba/__init__.py to make the model recognizable by fla, and register it in https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/models/__init__.py?
I think it should work with fla now.
Unfortunately, because the upstream change is still an open pull request against the transformers package, there is a possibility that it may not fully work in some specific cases; that's why it requires some testing.
@DanFosing Hi, looks like there are still some errors
File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 607, in forward
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 333, in cuda_kernels_forward
rmsnorm_weight=self.norm.weight,
^^^^^^^^^^^^^^^^
File "anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GatedRMSNorm' object has no attribute 'weight'
@DanFosing You can check your implementation by running
python -m benchmarks.benchmark_training_throughput --name mamba2
Also, it would be better to tidy up your code style via pre-commit so that it follows the PEP 8 guidelines.
You can refer to https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/layers/gla.py#L219-L228 for our implementation of the gate norm. I think there is no need to further wrap it in another GatedNorm.
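Unfused, a gate norm of this kind is only a few lines of plain PyTorch. A minimal sketch (the class name and the norm-before-gate ordering are illustrative assumptions based on the config's `"norm_before_gate": true`, not the fla Triton kernel); note that registering the scale as `self.weight` is exactly what the traceback above requires, since `cuda_kernels_forward` reads `self.norm.weight` directly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedRMSNormRef(nn.Module):
    """Reference (unfused) RMSNorm followed by a swish (SiLU) gate."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # registered as `weight` so callers can access `norm.weight`
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learned weight
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps) * self.weight
        # swish gate applied after the norm (norm_before_gate=True)
        return x * F.silu(gate)
```

The fused fla kernel does the same computation in one Triton pass; this sketch is only meant to pin down the semantics and the parameter name.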
@DanFosing Appreciate your hard work and quick response.
Are you sure it works? When I tried it on Kaggle I got assertion errors, but Kaggle often has weird problems, so it may be an issue on their side (I can't test it on my PC right now).
Could you paste more detailed info?
@DanFosing Here is the output
$ python -m benchmarks.benchmark_training_throughput --name mamba2
Initializing mamba2 model from the config:
Mamba2Config {
"bos_token_id": 1,
"chunk_size": 256,
"conv_kernel": 4,
"eos_token_id": 2,
"expand": 2,
"fuse_cross_entropy": true,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.1,
"layer_norm_epsilon": 1e-05,
"model_type": "mamba2",
"n_groups": 8,
"norm_before_gate": true,
"num_heads": 64,
"num_hidden_layers": 48,
"pad_token_id": 0,
"rescale_prenorm_residual": false,
"residual_in_fp32": true,
"rms_norm": true,
"state_size": 128,
"tie_word_embeddings": false,
"time_step_floor": 0.0001,
"time_step_limit": [
0.0,
Infinity
],
"time_step_max": 0.1,
"time_step_min": 0.001,
"time_step_rank": 128,
"transformers_version": "4.43.3",
"use_bias": false,
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
Mamba2ForCausalLM(
(backbone): Mamba2Model(
(embeddings): Embedding(32000, 2048)
(layers): ModuleList(
(0-47): 48 x Mamba2Block(
(norm): RMSNorm(2048, eps=1e-05)
(mixer): Mamba2Mixer(
(act): SiLU()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
(in_proj): Linear(in_features=2048, out_features=10304, bias=False)
(norm): FusedRMSNormSwishGate(4096, eps=1e-05)
(out_proj): Linear(in_features=4096, out_features=2048, bias=False)
)
)
)
(norm_f): RMSNorm(2048, eps=1e-05)
)
(lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1548430336 (1.44GiB)
Allocated memory after initialization: 2.89GiB
Max memory allocated: 42.01GiB: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:11<00:00, 4.44s/it]
Thoughput: 32048.92 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:16<00:00, 1.96it/s]
Turns out it was just a problem with the T4 not supporting some features of the causal_conv1d package; the same issue occurs with Mamba-1 and Mamba-2 from mamba-ssm. I don't remember the details, but there was a workaround that fixed it by turning off some low-memory mode in PyTorch or something similar. By the way, torch.compile support was added to Mamba-1, but it requires some changes in MambaCache and the Mamba modeling code; I will try to implement it when I have some time.
One question on this PR:
From the FLA paper, both Mamba-1 and Mamba-2 can be written in the gated linear attention formulation.
So the SSM part of Mamba-2 (excluding conv1d, normalization, ...) should permit the same "chunkwise parallel form" advocated by the FLA paper, no?
Then the modeling_mamba2.py implementation here shouldn't need to import mamba_chunk_scan_combined from the original mamba_ssm repository, right? It could use a custom kernel similar to ops/gla/chunk_fuse.py in this repo.
Otherwise, the modeling_mamba2.py in the current PR only uses the custom RMSNorm/FusedRMSNormSwishGate kernels, while the SSM part is no different from the original mamba_ssm repository. Then the performance will remain largely the same as the original repo, and you cannot tell whether the FLA formulation is faster...
The mamba-2 blog did mention that FLA chunkwise parallel "turns out to be essentially equivalent to the SSD algorithm specialized to a restricted case"
Will the FLA formulation be identical to the mamba_chunk_scan_combined in the original Mamba-2 code? Or is there still some chance to improve on the original version?
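To make the "chunkwise parallel form" concrete, here is a toy single-head reference in plain PyTorch with a scalar per-step decay (as in Mamba-2's SSD / simple-gla case). Names and shapes are illustrative only, not the fla kernel API; the point is that the chunkwise version reproduces the step-by-step recurrence exactly:

```python
import torch

def recurrent_gla(q, k, v, a):
    """Step-by-step form: S_t = a_t * S_{t-1} + k_t^T v_t,  o_t = q_t S_t.
    q, k: (T, d_k); v: (T, d_v); a: (T,) scalar decays in (0, 1]."""
    T, d_k = q.shape
    S = torch.zeros(d_k, v.shape[1], dtype=q.dtype)
    o = torch.empty(T, v.shape[1], dtype=q.dtype)
    for t in range(T):
        S = a[t] * S + torch.outer(k[t], v[t])
        o[t] = q[t] @ S
    return o

def chunkwise_gla(q, k, v, a, C=4):
    """Chunkwise parallel form: same output, processed C steps at a time."""
    T, d_k = q.shape
    S = torch.zeros(d_k, v.shape[1], dtype=q.dtype)
    o = torch.empty(T, v.shape[1], dtype=q.dtype)
    for start in range(0, T, C):
        qc, kc, vc, ac = (x[start:start + C] for x in (q, k, v, a))
        b = torch.cumprod(ac, 0)  # cumulative decay within the chunk
        # inter-chunk: carry the previous state through the cumulative decay
        o_inter = (b[:, None] * qc) @ S
        # intra-chunk: causal "attention" with relative decay b_t / b_s
        A = (qc @ kc.T) * (b[:, None] / b[None, :])
        o[start:start + C] = o_inter + A.tril() @ vc
        # state update: decay old state, accumulate decayed keys
        S = b[-1] * S + ((b[-1] / b)[:, None] * kc).T @ vc
    return o
```

The intra-chunk part is a dense (parallel, matmul-friendly) computation, while the state is still passed recurrently between chunks; that is the trade-off the FLA paper exploits. The real fla kernels (e.g. ops/gla/chunk_fuse.py) additionally handle batching, heads, per-channel decay, and numerical stability in log space.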
Indeed, I think it can be made faster if a kernel similar to chunk_fuse.py is used (maybe it would even be possible to just modify the GLA one, as GLA and Mamba-2 are extremely similar to each other). Unfortunately I'm not familiar with Triton, so I can't really do it myself; I think it would be best if you made an issue, unless one of the FLA devs replies here. The current implementation is a lot like the Mamba-1 implementation: both are mostly like the original mamba-ssm ones, with just a minor speed-up thanks to FusedCrossEntropyLoss and the custom RMSNorm, and both are compatible with Hugging Face transformers.
If you take a look at this paper: https://arxiv.org/pdf/2406.06484 you can also see that in terms of recurrence and memory read-out mamba-2 is very similar to RetNet and Linear Attention:
> I think it would be best if you made an issue or something, unless some of FLA devs replies there.
Will look into it!
Also pinging @yzhangcs for any suggestions -- like, which existing Triton kernel is the best starting point to re-implement the SSM part of Mamba-2? From the table above, the RetNet kernels should be the closest ones?
The mamba-2 blog also mentions the close connection between RetNet and Mamba-2 (under their SSD framework):
> Prior examples include the original linear attention as well as the recent Retentive Network (RetNet) model [18]. These can be viewed as direct special cases of SSD.
@learning-chip Hi, you may refer to simple gla by @sustcsonglin, which adds data-dependent decay on top of RetNet.
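The family resemblance between these models is easiest to see in the one-step update rule: at the SSM-core level they differ only in the decay term. A schematic sketch (shapes and names chosen for illustration, not any library's API):

```python
import torch

def gated_step(S, q, k, v, decay):
    """Generic gated linear-attention step with state S of shape (d_k, d_v):
        S_t = decay * S_{t-1} + k_t^T v_t,   o_t = q_t S_t
    The decay term is the only difference between the models:
      linear attention: decay = 1.0
      RetNet:           decay = gamma          (fixed scalar per head)
      Mamba-2 (SSD):    decay = a_t            (data-dependent scalar per head)
      simple gla:       decay = a_t            (same scalar-decay form)
      GLA:              decay = a_t[:, None]   (data-dependent per-channel vector)
    """
    S = decay * S + torch.outer(k, v)
    return S, q @ S
```

This is why a simple-gla-style kernel (RetNet plus data-dependent scalar decay) is the natural starting point for the Mamba-2 SSM part.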
Modified version of https://github.com/huggingface/transformers/tree/add_codestral_mamba2/src/transformers/models/mamba2 (https://github.com/huggingface/transformers/pull/32080), adapted to work with FLA. I haven't tested it yet, but I'm fairly sure it will work. It could probably be made a bit faster by implementing a gated RMSNorm that utilizes SiLU.