Closed · aashay-sarvam closed this issue 1 month ago

aashay-sarvam: How is the Mamba model being initialized?
In the paper it is suggested that the model gets some initialization from the Transformer at the start, but I don't see it in the code.
Hi aashay,
Thanks for raising this question. Here is the answer.
If you follow the full 3-step training pipeline, then in the first step we distill the model using pseudo-label distillation. We set init_with_kqvo to true, so this part is executed: https://github.com/jxiw/MambaInLlama/blob/main/mamba/hybrid_wrapper.py#L38-L46
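For readers who don't want to click through: the gist of that block is a weight copy, where the teacher's attention projections seed the new Mamba layer's linear projections so distillation starts close to the teacher. Below is a minimal sketch of the idea; the attribute names (x_proj, B_proj, C_proj, out_proj, and the *_proj names on the attention side) and the exact mapping are illustrative assumptions drawn from the paper's attention-to-SSM analogy, not the repo's actual code (see the linked lines for that).

```python
import torch

@torch.no_grad()
def init_ssm_from_attention(mamba_layer, attn_layer):
    # Sketch only: attribute names are hypothetical. Per the paper's
    # analogy between attention and linear SSMs, the SSM projections are
    # seeded as x <- V, B <- K, C <- Q, and the output projection <- O.
    mamba_layer.x_proj.weight.copy_(attn_layer.v_proj.weight)
    mamba_layer.B_proj.weight.copy_(attn_layer.k_proj.weight)
    mamba_layer.C_proj.weight.copy_(attn_layer.q_proj.weight)
    mamba_layer.out_proj.weight.copy_(attn_layer.o_proj.weight)
```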
If you skip the first step and only do SFT and DPO (the performance drops a bit), we still initialize the SSM's linear projection layers here: https://github.com/jxiw/MambaInLlama/blob/main/train_mamba/train_sft.py#L198-L199, since init_with_kqvo is set to True.
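To make the control flow concrete, here is a sketch of how such a flag can gate the same copy on the SFT-only path. make_hybrid, mamba_layers, and attn_layers are hypothetical names, and it reuses init_ssm_from_attention from the sketch above; the repo's real call site is at the linked lines.

```python
def build_hybrid_model(transformer_model, mamba_config, init_with_kqvo=True):
    # make_hybrid is a hypothetical helper that swaps attention layers
    # for Mamba layers; the repo's actual entry point is linked above.
    model = make_hybrid(transformer_model, mamba_config)
    if init_with_kqvo:
        # Same copy as in the distillation path: each new SSM layer is
        # seeded from the corresponding teacher attention layer.
        for mamba_layer, attn_layer in zip(model.mamba_layers,
                                           transformer_model.attn_layers):
            init_ssm_from_attention(mamba_layer, attn_layer)
    return model
```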
Best, Junxiong