state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Query Regarding Mamba Model Performance Tuning #22

Open yihong1120 opened 9 months ago

yihong1120 commented 9 months ago

Dear Mamba Contributors,

I hope this message finds you well. I am in the process of utilising the Mamba state space architecture for a language modelling task and have been highly impressed with the innovative approach adopted in this project.

However, whilst working with the pretrained Mamba models provided, I have observed anomalous behaviour concerning model stability during training. Despite following the recommended guidelines and ensuring that my system meets the specified prerequisites (i.e., PyTorch 1.12+, CUDA 11.6+, a Linux-based system with an NVIDIA GPU), the SSM layers appear sensitive to their recurrent dynamics, leading to unexpected fluctuations.

It is indicated within the 'Troubleshooting' section that mixed precision via PyTorch's AMP maintains model parameters in float32 and casts to half precision when necessary. Nevertheless, might there be an alternative approach, or further suggestions you could offer, to enhance the model's stability? Any additional insight into configurations that might mitigate the aforementioned stability issues would be most welcome.
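
For what it's worth, my understanding of that pattern is roughly the generic PyTorch training loop below (a sketch with hypothetical `build_model` and `dataloader` names, not code from this repository): the parameters stay in float32, and autocast only casts eligible ops to half precision inside the forward pass.

```python
import torch

# Parameters stay in float32; do NOT call model.half().
model = build_model().cuda()                      # hypothetical constructor
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for input_ids, labels in dataloader:              # hypothetical dataloader
    optimizer.zero_grad(set_to_none=True)
    # autocast casts eligible ops to half precision only inside the forward pass
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(input_ids, labels=labels)    # hypothetical signature returning the loss
    loss.backward()                               # gradients accumulate in fp32, matching the parameters
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional: clipping can also help stability
    optimizer.step()
```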

Moreover, I am curious whether there are any plans for updated releases or patches that might offer improved robustness or address the concerns described above.

I would like to express my gratitude for making such a groundbreaking model accessible to the public and for your commitment to advancing the field of machine learning. I look forward to your guidance on the matter and any subsequent versions of Mamba that may further polish its performance.

Thank you for your time and assistance.

Warm regards,

yihong1120

albertfgu commented 9 months ago

I can't tell from your message: which training framework are you using and is it properly keeping parameters in fp32?
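
(A quick way to check, assuming `model` is the instantiated Mamba LM, is something along these lines:)

```python
import torch

# List any parameters that are not float32; an empty result means the
# master weights are being kept in fp32 as intended.
non_fp32 = {name: p.dtype for name, p in model.named_parameters() if p.dtype != torch.float32}
print(non_fp32 or "all parameters are float32")
```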

I will say that since training the original models, we have occasionally noticed potential instabilities, and we have a fix that seems to address them most of the time. Essentially, we insert an extra LayerNorm/RMSNorm in a specific spot. You might already be able to do this yourself by going into the model details, but it's also on our roadmap to provide a version with this functionality.
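
Purely as an illustration of what inserting an extra RMSNorm can look like (the wrapper below is hypothetical, not the repo's `Block` class, and the actual placement referred to above lives in the model code):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Standard RMSNorm; normalization is done in fp32 for numerical stability."""
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)

class RenormalizedBlock(nn.Module):
    """Hypothetical wrapper: re-normalize a block's output before the residual add."""
    def __init__(self, block, d_model):
        super().__init__()
        self.block = block
        self.extra_norm = RMSNorm(d_model)

    def forward(self, x):
        return x + self.extra_norm(self.block(x))
```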