This PR brings the Jamba architecture implementation in PyTorch to the repo.
All the Jamba-related code lives in `jamba.py`. Specifically, this file contains:
- the attention computation (GQA, which uses `torch.nn.functional.scaled_dot_product_attention`),
- the sparse MoE MLP layer,
- and the Jamba architecture itself, which combines the `MambaBlock` from `mamba.py` with the components listed above.
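To illustrate the attention part, here is a minimal, hypothetical sketch of GQA built on `scaled_dot_product_attention` — the shapes and names are illustrative, not the exact code in `jamba.py`:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: 8 query heads share 2 KV heads (4 queries per KV head).
B, T, n_heads, n_kv_heads, head_dim = 2, 8, 8, 2, 16

q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_kv_heads, T, head_dim)
v = torch.randn(B, n_kv_heads, T, head_dim)

# GQA: repeat each KV head so it lines up with its group of query heads.
repeats = n_heads // n_kv_heads
k = k.repeat_interleave(repeats, dim=1)
v = v.repeat_interleave(repeats, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# out: (B, n_heads, T, head_dim)
```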
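And for the MoE part, a hypothetical sketch of a sparse top-k MoE MLP (again not the exact `jamba.py` code): a learned gate picks k experts per token and mixes their outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    def __init__(self, d_model=32, d_ff=64, n_experts=4, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        )

    def forward(self, x):  # x: (B, T, d_model)
        scores = self.gate(x)                       # (B, T, n_experts)
        weights, idx = scores.topk(self.k, dim=-1)  # top-k experts per token
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        # For clarity this runs every expert on all tokens and masks the mix;
        # real implementations dispatch only the routed tokens to each expert.
        for i, expert in enumerate(self.experts):
            picked = (idx == i)                     # (B, T, k)
            if picked.any():
                w = (weights * picked).sum(-1, keepdim=True)  # (B, T, 1)
                out = out + w * expert(x)
        return out
```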
This Jamba implementation was used to train a 9.5M-parameter Jamba model in the OthelloGPT experiment. Check the results here: Othello Mamba repo.
Additionally, `mamba.py` now supports using the official CUDA code as a backend. You can thus train a Mamba or a Jamba model with this repo, enjoy an easy-to-read implementation, and still get the best performance. This choice was motivated by the goal of being able to train a Jamba model efficiently in plain PyTorch. As of today, the Jamba architecture is otherwise only available through the `transformers` library.
Speaking of the Jamba implementation in the `transformers` library, it is located in a 2,500-line file, and thus not that easy to read. `jamba.py` is about 530 effective lines, and I think it is easier to read and tinker with.
I chose to put everything (attention + MoE) in `jamba.py` so as not to clutter the repo with additional files like `attention.py`, `moe.py`, and so on.
Next step: Jamba in MLX!