jxiw / MambaInLlama

Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

Mamba Model initialisation #8

Closed: aashay-sarvam closed this issue 1 month ago

aashay-sarvam commented 1 month ago

How is the Mamba model being initialised?

[Screenshot attached]

The paper suggests that at the start the model gets some of its initialisation from the transformer, but I don't see this in the code.

jxiw commented 1 month ago

Hi aashay,

Thanks for raising this question. Here is the answer.

If you follow the 3-step training: in the first step, we distill the model using pseudo-label distillation and set init_with_kqvo to true, so this part is executed: https://github.com/jxiw/MambaInLlama/blob/main/mamba/hybrid_wrapper.py#L38-L46
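
To make the idea concrete, here is a minimal, self-contained sketch (not the repo's exact code) of the kind of weight copy that init_with_kqvo enables: the SSM's linear projections are seeded from the transformer's attention projections, roughly C from W_Q, B from W_K, x from W_V, and the output projection from W_O. All module and attribute names below are hypothetical stand-ins for the actual modules in hybrid_wrapper.py.

```python
import torch
import torch.nn as nn

hidden = 64  # toy hidden size for the illustration

# Stand-in for a transformer attention block (square projections for simplicity).
attn = nn.ModuleDict({
    "q_proj": nn.Linear(hidden, hidden, bias=False),
    "k_proj": nn.Linear(hidden, hidden, bias=False),
    "v_proj": nn.Linear(hidden, hidden, bias=False),
    "o_proj": nn.Linear(hidden, hidden, bias=False),
})

# Stand-in for the Mamba layer's linear projections.
mamba = nn.ModuleDict({
    "C_proj":   nn.Linear(hidden, hidden, bias=False),
    "B_proj":   nn.Linear(hidden, hidden, bias=False),
    "x_proj":   nn.Linear(hidden, hidden, bias=False),
    "out_proj": nn.Linear(hidden, hidden, bias=False),
})

@torch.no_grad()
def init_with_kqvo_sketch(mamba, attn):
    """Copy attention projection weights into the SSM's linear projections."""
    mamba["C_proj"].weight.copy_(attn["q_proj"].weight)    # C from Q
    mamba["B_proj"].weight.copy_(attn["k_proj"].weight)    # B from K
    mamba["x_proj"].weight.copy_(attn["v_proj"].weight)    # x from V
    mamba["out_proj"].weight.copy_(attn["o_proj"].weight)  # output from O

init_with_kqvo_sketch(mamba, attn)
assert torch.equal(mamba["B_proj"].weight, attn["k_proj"].weight)
```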

If you skip the first step and only do SFT and DPO (the performance drops a bit), we still initialize the linear projection layers of the SSM via https://github.com/jxiw/MambaInLlama/blob/main/train_mamba/train_sft.py#L198-L199, since init_with_kqvo is set to True.
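
In other words, the SFT-only path still constructs the hybrid model with the flag enabled, so the copy sketched above still runs. The toy code below only illustrates that gating; everything except the init_with_kqvo flag name is a hypothetical stand-in for the construction at train_mamba/train_sft.py#L198-L199, and it reuses the definitions from the previous sketch.

```python
def build_hybrid_sketch(mamba, attn, init_with_kqvo: bool = True):
    """Toy stand-in for the hybrid-model construction in the SFT script."""
    if init_with_kqvo:
        # Seed the SSM projections from the attention weights instead of
        # leaving them at their random initialization.
        init_with_kqvo_sketch(mamba, attn)
    return mamba

# SFT/DPO-only run: the flag is still True, so the attention-to-SSM copy happens
# even though the pseudo-label distillation step is skipped.
hybrid = build_hybrid_sketch(mamba, attn, init_with_kqvo=True)
```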

Best, Junxiong