state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Quantization #133

Open arman-kazemi opened 10 months ago

arman-kazemi commented 10 months ago

Hi, have you tried quantizing Mamba? Do you plan on releasing quantized versions? Can you share your thoughts on quantizing Mamba, given the sensitivity of the model's recurrent dynamics? Thanks

tridao commented 10 months ago

We have not tried quantization; it's an open question. It would be very interesting to understand how sensitive the model is to the SSM params. E.g., I could imagine quantizing the nn.Linear weights but keeping the SSM params and states in high precision.
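
A minimal sketch of that idea using PyTorch dynamic quantization, which by default only swaps out nn.Linear modules; the mamba_ssm import path is the one from this repo, and the checkpoint name is illustrative:

```python
import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load a pretrained Mamba LM (checkpoint name is illustrative).
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")

# Dynamic int8 quantization only replaces nn.Linear modules (in_proj,
# x_proj, dt_proj, out_proj, lm_head), so the SSM parameters (A_log, D)
# and the recurrent state stay in floating point. Note that PyTorch
# dynamic quantization currently executes on CPU.
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
```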

radna0 commented 5 months ago

I would love an update on this

hychiang-git commented 4 months ago

Hello, we have some initial results to share, but the work is still under review. Please see the preview at https://hychiang.info/projects/quamba/

kmheckel commented 4 months ago

Here's a paper being presented at the Next-Generation Sequence Modeling Workshop at ICML next week: https://arxiv.org/abs/2406.09477

The takeaway is that for quantization-aware training and inference on LRA, most parameters can be quantized below 8 bits (uint8), but the recurrent matrix A/lambda is the most sensitive and performance changes dramatically under 8 bits.

This recent preprint might also be of interest: https://arxiv.org/abs/2407.12397
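
To make the takeaway above concrete, here is a minimal fake-quantization sketch over a single Mamba block from this repo; the symmetric uniform quantizer and the parameter grouping (everything except A_log pushed to 4 bits) are my own assumptions, not the paper's exact protocol:

```python
import torch
from mamba_ssm import Mamba

def fake_quant(w: torch.Tensor, bits: int) -> torch.Tensor:
    # Symmetric uniform fake quantization: round to `bits` levels, then dequantize.
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-8) / qmax
    return (w / scale).round().clamp(-qmax, qmax) * scale

block = Mamba(d_model=256)  # a single Mamba block, just for illustration

with torch.no_grad():
    for name, p in block.named_parameters():
        # A_log parameterizes the recurrent matrix A in this repo's Mamba
        # block; keep it at 8 bits and push everything else to 4 bits.
        bits = 8 if "A_log" in name else 4
        p.copy_(fake_quant(p, bits))
```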

dustydecapod commented 3 months ago

> Here's a paper being presented at the Next-Generation Sequence Modeling Workshop at ICML next week: https://arxiv.org/abs/2406.09477
>
> The takeaway is that for quantization-aware training and inference on LRA, most parameters can be quantized below 8 bits (uint8), but the recurrent matrix A/lambda is the most sensitive and performance changes dramatically under 8 bits.
>
> This recent preprint might also be of interest: https://arxiv.org/abs/2407.12397

So basically mixed precision, with int4 for most of the weights and int8 for the A/lambda matrix, is reasonable?

kmheckel commented 3 months ago

Yup, this is similar to the more recent ExpertsInt8 from Jamba 1.5, where the MLPs in the MoEs are quantized: https://www.ai21.com/blog/announcing-jamba-model-family

But yeah, generally the order of sensitivity is MLPs < non-recurrent parameters inside the SSM block < A/Lambda/recurrent matrix in the SSM block. Not sure about the impact on the other mechanisms in Mamba, but it's a start.
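
As a rough sketch, that sensitivity ordering could be expressed as a bit-width policy like the one below; the parameter-name patterns follow this repo's Mamba block and are assumptions, not a published recipe:

```python
def bits_for(param_name: str) -> int:
    """Assign a quantization bit width by sensitivity (illustrative only)."""
    if "A_log" in param_name:
        return 16  # recurrent matrix A: most sensitive, keep in high precision
    if any(key in param_name for key in ("x_proj", "dt_proj", ".D")):
        return 8   # non-recurrent parameters inside the SSM block
    return 4       # remaining linear / MLP-style weights: least sensitive
```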