state-spaces / mamba

Mamba SSM architecture

When to use ngroups in mamba-v2? #442

Closed ScottHoang closed 3 months ago

ScottHoang commented 3 months ago

In my use case, I am applying PyTorch's FSDP to Mamba over 8 GPUs. Should ngroups be increased to 8?

tridao commented 3 months ago

No, it's not necessary. You should only increase ngroups if you're using tensor parallelism.
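A minimal sketch of the two cases, assuming the `Mamba2` block from the `mamba_ssm` package exposes an `ngroups` argument (default 1) and that the script is launched with `torchrun --nproc_per_node=8`; the model dimensions here are illustrative, not taken from any released checkpoint:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mamba_ssm import Mamba2

# FSDP is data parallelism: parameters are sharded across ranks but each
# rank still computes the full layer, so the default ngroups=1 is fine.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

block = Mamba2(d_model=2560, d_state=128, headdim=64, ngroups=1).cuda()
block = FSDP(block)

# Tensor parallelism is the case where ngroups would be raised to match
# the TP world size, so each TP shard keeps whole B/C groups, e.g.:
# block = Mamba2(d_model=2560, d_state=128, headdim=64, ngroups=8)
```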

ScottHoang commented 3 months ago

Thank you! I appreciate the answer.

Also, unlike the original Mamba, the mamba2-2.7b checkpoint on Hugging Face does not contain a feed-forward network (the intermediate size is set to 0). Is this correct?

tridao commented 3 months ago

Neither Mamba-1 nor Mamba-2 has a feed-forward network. The default is d_intermediate = 0.
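A sketch of what that default means on the repo side, assuming the `MambaConfig` dataclass in recent `mamba_ssm` releases exposes a `d_intermediate` field; the dimensions mirror the 2.8b config but are written out here only for illustration:

```python
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# d_intermediate=0 (the default) means each block is just norm + mixer,
# with no MLP/feed-forward sublayer between Mamba layers.
cfg = MambaConfig(d_model=2560, n_layer=64, vocab_size=50277, d_intermediate=0)
model = MambaLMHeadModel(cfg)
```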

ScottHoang commented 3 months ago

Mamba-1: https://huggingface.co/state-spaces/mamba-2.8b-hf/blob/main/config.json has an intermediate_size of twice the model dimension.

tridao commented 3 months ago

That's the HF config. I'm not familiar with how it's implemented in HF transformers. They might use "intermediate_size" to mean something else (not a feed-forward). Here's the config used for the implementation in this repo: https://huggingface.co/state-spaces/mamba-2.8b/blob/main/config.json
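One way to check this, assuming the reading that HF's `intermediate_size` is the expanded inner SSM width (`expand * hidden_size`) rather than a feed-forward hidden size; this needs the `transformers` package and access to huggingface.co:

```python
from transformers import AutoConfig

# Pull the HF-side config for the converted checkpoint discussed above.
hf_cfg = AutoConfig.from_pretrained("state-spaces/mamba-2.8b-hf")
print(hf_cfg.hidden_size, hf_cfg.intermediate_size, hf_cfg.expand)
# Expected, if the reading above is right: 2560 5120 2,
# i.e. intermediate_size == expand * hidden_size, not an MLP width.
```

If that holds, the 2x factor in the HF config is just the Mamba block's internal expansion, consistent with the repo config having no feed-forward network at all.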

ScottHoang commented 3 months ago

Got it! Thank you!