johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0
2.62k stars 191 forks source link

mamba-minimal

Simple, minimal implementation of Mamba in one file of PyTorch.

Featuring:

Does NOT include:

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!

References

The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba/tree/main