alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

huge huge memory usage!! #9

Open eisneim opened 8 months ago

eisneim commented 8 months ago

I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace the pscan method with another implementation?

Thanks a lot!

alxndrTL commented 8 months ago

Indeed, it is normal that this version uses a lot of memory, as it doesn't use the recomputation technique described in the paper (see this note in the README for more information). As of now, if you can use the default CUDA implementation, go for it (if you have a recent enough NVIDIA GPU), as mamba.py is mostly designed for educational purposes. If not, you can maybe look here to implement the recomputation yourself. I'm working on a "performance update" and I hope I will be able to include the recomputation technique.

Here is my benchmark for training a Mamba with d_model=512, n_layers=16, B=16, L=512 on an A100 80GB:
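In the meantime, if you want a stopgap in pure PyTorch, here is a minimal sketch of the recomputation idea using PyTorch's built-in gradient checkpointing (the wrapper and block names below are placeholders, not the repo's actual modules). It discards each block's intermediate activations during the forward pass and recomputes them in backward, trading extra compute for lower peak memory:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlocks(nn.Module):
    """Hypothetical wrapper: recompute each block's activations during
    backward instead of storing them, lowering peak memory at the cost
    of roughly one extra forward pass per block."""
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks  # e.g. the stack of Mamba residual blocks

    def forward(self, x):
        for block in self.blocks:
            # use_reentrant=False is the recommended checkpointing path
            x = checkpoint(block, x, use_reentrant=False)
        return x
```

This is coarser than the fused-kernel recomputation of the official CUDA implementation, but it is a cheap way to cut activation memory while staying in pure PyTorch.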

eisneim commented 8 months ago

@alxndrTL thanks, this is such great work! I'm using an M3 Max MacBook, so the CUDA implementation is unfortunately not an option. I was trying to use your Mamba code with the method described in the Vision Mamba (Vim) paper: I got CLIP training working with Mamba as the text encoder and Vision Mamba as the vision encoder, but compared to a ViT, memory usage increased at least 5 times! Maybe creating a custom operator using Metal would solve this issue? (apple doc)

looking forward to your "performance update"

alxndrTL commented 8 months ago

Yes, using Metal would be an option, but MLX is also in the picture now; I don't yet know whether the two are exclusive. I guess it won't be long before MLX ships an efficient (kernel-optimized) pscan implementation. I will keep you posted here about the performance update!
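Until then, the memory-light (but slow) fallback is just a sequential scan. Here is a rough MLX sketch of the recurrence h_t = A_t * h_{t-1} + X_t, with shapes simplified compared to the real selective scan:

```python
import mlx.core as mx

def sequential_scan(A, X):
    """Naive sequential scan: h_t = A_t * h_{t-1} + X_t.
    A, X: (B, L, D) arrays. Returns all hidden states, shape (B, L, D).
    Uses little memory but runs L sequential steps, so it is much slower
    than a kernel-optimized parallel scan on long sequences."""
    B, L, D = X.shape
    h = mx.zeros((B, D))
    out = []
    for t in range(L):
        h = A[:, t] * h + X[:, t]
        out.append(h)
    return mx.stack(out, axis=1)
```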

ali-shihab commented 5 months ago

Hey, any updates on the performance/pscan work? My 16GB M1 Pro just crashed while loading mamba-2.8b for fine-tuning on a classification task, so I'm looking forward to this!

alxndrTL commented 5 months ago

Hello @ali-shihab, the performance update was pushed ≈2 months ago, but it only improves training speed, not memory usage. I've just been working on Jamba, so I'm taking some time off, but I'm aware that this is the main problem with mamba.py right now (and it kind of makes it unusable for real scenarios).

ali-shihab commented 5 months ago

Hey @alxndrTL, thanks for the quick response. If I'm not mistaken, the performance update is for PyTorch, no? Correct me if I'm wrong.

Additionally, I've quickly scanned over jamba.py, and it seems everything in there can be implemented in MLX. Do you think I could do this, or is there something that stopped you from doing it? If it's feasible, I'll see if I can implement a local version in MLX and clean it up for a PR once I have some time.

Also, that is completely understandable, enjoy your time off :)

alxndrTL commented 5 months ago

Yes, the performance update is for PyTorch only, as the pscan currently works quite poorly in MLX (see the comments in https://github.com/alxndrTL/mamba.py/blob/main/mlx/pscan_mlx.py, tested with MLX in January).

For the Jamba implementation in MLX, that would be very welcome! The one sore point I see would be replacing F.scaled_dot_product_attention; I don't know whether there is a FlashAttention implementation in MLX, otherwise you just have to do the attention computations by hand.
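To be concrete, here is a rough sketch of what the hand-written attention could look like in MLX (no FlashAttention-style memory savings; the full (L, L) score matrix is materialized):

```python
import math
import mlx.core as mx

def sdpa(q, k, v, mask=None):
    """Plain scaled dot-product attention: softmax(q k^T / sqrt(d)) v.
    q, k, v: (B, n_heads, L, head_dim). mask, if given, is an additive
    mask (e.g. -inf above the diagonal for causal attention)."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * scale
    if mask is not None:
        scores = scores + mask
    weights = mx.softmax(scores, axis=-1)
    return weights @ v
```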

M-I commented 2 months ago

flash attention is not there (yet): https://github.com/ml-explore/mlx-examples/issues/724#issuecomment-2101093352