alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

Jamba PyTorch #24

Closed alxndrTL closed 5 months ago

alxndrTL commented 5 months ago

This PR brings the Jamba architecture implementation in PyTorch to the repo.

All the Jamba-related code is located in the jamba.py file.

This Jamba implementation was used to train a 9.5M-parameter Jamba model in the OthelloGPT experiment. Check the results here: Othello Mamba repo.
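To give an idea of how the model is meant to be used, here is a minimal usage sketch; the class and field names (`JambaConfig`, `Jamba`, `d_model`, `n_layers`) are assumptions for illustration, so check `jamba.py` for the actual API:

```python
import torch
from jamba import JambaConfig, Jamba  # assumed names; see jamba.py for the real API

# A small hypothetical configuration, in the spirit of the 9.5M OthelloGPT run.
config = JambaConfig(
    d_model=256,   # model width (assumed field name)
    n_layers=8,    # total number of blocks, Mamba + attention interleaved (assumed field name)
)
model = Jamba(config)

# Forward pass over a dummy batch of embeddings: (batch, seq_len, d_model) in, same shape out.
x = torch.randn(2, 64, 256)
y = model(x)
print(y.shape)  # torch.Size([2, 64, 256])
```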

Additionally, mamba.py now supports using the official CUDA code as a backend. You can thus train a Mamba or a Jamba model with this repo and get both an easy-to-read implementation and the best performance. This choice was motivated by the goal of training a Jamba model efficiently in plain PyTorch. As of today, the Jamba architecture is otherwise only available through the transformers library.
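As a rough sketch of what selecting the backend could look like, assuming the choice is exposed as a `use_cuda` flag on the config (an assumption for illustration; the actual option name is defined in `mamba.py`):

```python
import torch
from mamba import Mamba, MambaConfig  # module/class names assumed from the repo layout

# use_cuda is an assumed flag name: when the official CUDA kernels (mamba-ssm) are
# installed, the selective scan would be dispatched to them; otherwise the pure
# PyTorch parallel scan is used. The forward interface stays the same either way.
config = MambaConfig(d_model=256, n_layers=4, use_cuda=True)
model = Mamba(config).to("cuda")

x = torch.randn(2, 64, 256, device="cuda")
y = model(x)  # (2, 64, 256), identical call whichever backend runs underneath
```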

Speaking of the Jamba implementation in the transformers library, it lives in a 2,500-line file and is therefore not that easy to read. jamba.py is 530 effective lines long, and I think it is easier to read and tinker with. I chose to put everything (attention + MoE) in the jamba.py file so as not to clutter the repo with additional files like attention.py, moe.py or whatever...

Next step: Jamba in MLX!