This is a highly experimental implementation of Mamba2 [1] that is compatible with the transformers library by Hugging Face [2]. It only supports the pure Mamba2 block, which means the hybrid variants with Attention and/or MLP are not available.
NOTE: You can use this repo to run Mamba2-based models with all optimisation paths:
- Triton kernels combined with the causal-conv1d package (fastest)
- Triton kernels only
- Pure PyTorch (fallback)
NOTE: I'm not affiliated with the original authors of Mamba2 or Hugging Face.
I won't distribute a PyPI package, but you can use it as a package by cloning the repo and installing it from the root directory:
git clone https://github.com/vasqu/mamba2-torch.git
cd mamba2-torch
pip install .
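A quick way to sanity-check the install, using the same import path as the usage examples below:
python -c "from mamba2_torch import Mamba2ForCausalLM"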
To use the "fastest" path, you need to install the causal-conv1d package separately.
To use any pretrained Mamba2 model, you need the respective model in a compatible format. You have two options:
# example usage to download mamba2-130m
# 1st argument = parameter count, 2nd argument = directory to save the converted model to
./download_mamba2.sh 130m ../models
# example usage to download and convert mamba2-130m
# 1st argument = parameter count, 2nd argument = directory to save the converted model to
./convert_mamba2.sh 130m ../models
Now you can use the converted model in the following way:
from transformers import AutoTokenizer
from mamba2_torch import Mamba2Model, Mamba2ForCausalLM, Mamba2Config
device = "cuda"
mamba2_hf_path = "<path-to-converted-model>"
model = Mamba2ForCausalLM.from_pretrained(mamba2_hf_path, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(mamba2_hf_path, local_files_only=True)
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
# expected output (130m): `["Hey how are you doing?\n\nI'm in the middle of a project"]`
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
Some optional features to give more control over the model:
from transformers import AutoTokenizer
from mamba2_torch import Mamba2Model, Mamba2ForCausalLM, Mamba2Config
mamba2_hf_path = "<path-to-converted-model>"
# flag to enable / disable the triton kernels
# --> if disabled, the pure PyTorch implementation is used instead
config = Mamba2Config.from_pretrained(mamba2_hf_path, local_files_only=True)
config.use_triton_kernels = False
model = Mamba2ForCausalLM.from_pretrained(mamba2_hf_path, config=config, local_files_only=True)
...
from transformers import AutoTokenizer
from mamba2_torch import Mamba2Model, Mamba2ForCausalLM, Mamba2Config
device = "cuda"
mamba2_hf_path = "<path-to-converted-model>"
# flag to enable / disable outputting last SSM states
config = Mamba2Config.from_pretrained(mamba2_hf_path, local_files_only=True)
config.output_last_ssm_states = True
model = Mamba2ForCausalLM.from_pretrained(mamba2_hf_path, config=config, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(mamba2_hf_path, local_files_only=True)
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
# or do it in the forward pass directly
out = model(input_ids, output_last_ssm_states=True)
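The collected states should then be available on the returned output object; the attribute name used below (last_ssm_states) is an assumption derived from the config flag and may differ:
# assumed attribute name; expected to hold one state tensor per layer
last_ssm_states = out.last_ssm_states
# per layer: (batch_size, num_heads, head_dim, state_size)
print(len(last_ssm_states), last_ssm_states[0].shape)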
import torch
from transformers import AutoTokenizer
from mamba2_torch import Mamba2Model, Mamba2ForCausalLM, Mamba2Config
device = "cuda"
mamba2_hf_path = "<path-to-converted-model>"
model = Mamba2ForCausalLM.from_pretrained(mamba2_hf_path, local_files_only=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(mamba2_hf_path, local_files_only=True)
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
# creating random initial states
config = Mamba2Config.from_pretrained(mamba2_hf_path, local_files_only=True)
initial_states = [
    torch.randn(size=(input_ids.shape[0], config.num_heads, config.head_dim, config.state_size)).to(device)
for _ in range(config.num_hidden_layers)
]
# don't pass an initial state to the 5th block
initial_states[4] = None
# pass it in the forward call
out = model(input_ids, initial_states=initial_states)
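As a rough sketch (not a documented API), the two optional features above could be combined to carry SSM state from one forward call into the next; last_ssm_states is again an assumed attribute name:
# run a prefix, grab its final SSM states, and feed them back in as initial states
prefix_ids = tokenizer("Hey how are", return_tensors="pt")["input_ids"].to(device)
prefix_out = model(prefix_ids, output_last_ssm_states=True)
carried_states = list(prefix_out.last_ssm_states)  # assumed attribute name, one tensor per layer
continuation_ids = tokenizer(" you doing?", return_tensors="pt")["input_ids"].to(device)
out = model(continuation_ids, initial_states=carried_states)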
Some additional notes:
- D residual connection: a small test that checks roughly equal outputs is over here.
- mamba_ssm: it should work with the triton kernels here as well.
- ( (d_model * expand) / headdim ) % 8 == 0 should hold, i.e. the number of heads should be a multiple of 8.
- The tie_embedding_weights flag in the config is probably enforced in any case. Not too interested in digging into this but open for PRs.

[1] Mamba2
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
[2] Hugging Face
@inproceedings{wolf-etal-2020-transformers,
title = "Transformers: State-of-the-Art Natural Language Processing",
author = "Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush",
booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations",
month = oct,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6",
pages = "38--45"
}