johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0

Memory issue due to A and B matrix computation #25

Open · anupsingh15 opened this issue 6 months ago

anupsingh15 commented 6 months ago

Hi, thanks for providing the Mamba implementation. I would like to know whether there is a workaround for the computation of deltaA and deltaB_u that avoids running out of GPU memory. These are the parameters I used to create the Mamba instance:

d_model: int = 1024
n_layer: int = 4
d_state: int = 1024
expand: int = 2

The other parameters are set to their default values.
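For reference, a sketch of how this configuration might be instantiated, assuming mamba-minimal's `ModelArgs` and `Mamba` classes from `model.py`; `vocab_size` is a placeholder, since it is not stated above (and the total parameter count depends on it):

```python
# Sketch only: ModelArgs/Mamba come from mamba-minimal's model.py;
# vocab_size below is a placeholder, not a value from this issue.
from model import Mamba, ModelArgs

args = ModelArgs(
    d_model=1024,
    n_layer=4,
    vocab_size=10000,   # placeholder; not specified in the issue
    d_state=1024,
    expand=2,
)
model = Mamba(args)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```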

This configuration results in a model of ~60M parameters. However, I run out of memory (max GPU memory = 24 GB) when I train with a batch size of 256, or even one as low as 64. This probably happens because of the large intermediate tensors materialized for deltaA and deltaB_u.
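A back-of-the-envelope estimate shows why this configuration cannot fit in 24 GB: mamba-minimal materializes deltaA and deltaB_u as full (batch, seq_len, d_inner, d_state) tensors in its selective scan. The sketch below assumes seq_len = 512, which is not stated above; the other numbers match the config:

```python
# Rough memory estimate for the (batch, seq_len, d_inner, d_state)
# intermediates that mamba-minimal materializes for deltaA and deltaB_u.
# seq_len = 512 is an assumed value for illustration.
batch, seq_len = 64, 512
d_model, expand, d_state = 1024, 2, 1024
d_inner = expand * d_model                        # 2048

elems = batch * seq_len * d_inner * d_state       # elements per tensor
gib = 2 * elems * 4 / 2**30                       # two fp32 tensors (deltaA, deltaB_u)
print(f"{gib:.0f} GiB total")                     # 512 GiB, before autograd overhead

per_sample = 2 * seq_len * d_inner * d_state * 4 / 2**30
print(f"{per_sample:.0f} GiB per batch element")  # 8 GiB even at batch size 1
```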

shim0114 commented 6 months ago

I also have this issue...!

johnma2006 commented 5 months ago

This repo is mostly meant for educational purposes, so I would suggest using the official repo for any training: https://github.com/state-spaces/mamba
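For example, a minimal sketch of swapping in the official implementation, assuming the `Mamba` block API shown in that repo's README (the pip package and argument names come from there, not from this thread; note the official fused kernels target small d_state values like the default 16, not 1024):

```python
# Sketch only: assumes `pip install mamba-ssm` and a CUDA device, per the
# README at https://github.com/state-spaces/mamba.
import torch
from mamba_ssm import Mamba

layer = Mamba(
    d_model=1024,  # matches the issue's config
    d_state=16,    # official default; kernels target small d_state, not 1024
    d_conv=4,
    expand=2,
).to("cuda")

x = torch.randn(8, 512, 1024, device="cuda")  # (batch, seq_len, d_model)
y = layer(x)    # fused selective scan avoids materializing (b, l, d_in, n)
print(y.shape)  # torch.Size([8, 512, 1024])
```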

XZJIsme commented 5 months ago

I also ran into this OOM problem recently, though not when using this repo's code. You may refer to this question on Stack Overflow.