jxiw / MambaInLlama

Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

Why is it that the config file doesn't seem to cover all the parameters when I apply the Mamba2InLlama3B_Half_DPO model? #14

Open XDjiang25 opened 1 week ago

XDjiang25 commented 1 week ago

Your idea is great, but I seem to be having trouble reproducing it:

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM: Missing key(s) in state_dict: "model.layers.0.mamba.dt_bias", "model.layers.0.mamba.A_log", "model.layers.0.mamba.D", "model.layers.0.mamba.in_proj.weight", "model.layers.0.mamba.conv1d.weight", "model.layers.0.mamba.conv1d.bias", "model.layers.0.mamba.norm.weight", "model.layers.0.mamba.out_proj.weight", "model.layers.2.mamba.dt_bias", "model.layers.2.mamba.A_log", "model.layers.2.mamba.D", "model.layers.2.mamba.in_proj.weight", "model.layers.2.mamba.conv1d.weight", "model.layers.2.mamba.conv1d.bias", "model.layers.2.mamba.norm.weight", "model.layers.2.mamba.out_proj.weight", "model.layers.4.mamba.dt_bias", "model.layers.4.mamba.A_log", "model.layers.4.mamba.D", "model.layers.4.mamba.in_proj.weight", "model.layers.4.mamba.conv1d.weight", "model.layers.4.mamba.conv1d.bias", "model.layers.4.mamba.norm.weight", "model.layers.4.mamba.out_proj.weight", "model.layers.6.mamba.dt_bias", "model.layers.6.mamba.A_log", "model.layers.6.mamba.D", "model.layers.6.mamba.in_proj.weight", "model.layers.6.mamba.conv1d.weight", "model.layers.6.mamba.conv1d.bias", "model.layers.6.mamba.norm.weight", "model.layers.6.mamba.out_proj.weight", "model.layers.8.mamba.dt_bias", "model.layers.8.mamba.A_log", "model.layers.8.mamba.D", "model.layers.8.mamba.in_proj.weight", "model.layers.8.mamba.conv1d.weight", "model.layers.8.mamba.conv1d.bias", "model.layers.8.mamba.norm.weight", "model.layers.8.mamba.out_proj.weight", "model.layers.10.mamba.dt_bias", "model.layers.10.mamba.A_log", "model.layers.10.mamba.D", "model.layers.10.mamba.in_proj.weight", "model.layers.10.mamba.conv1d.weight", "model.layers.10.mamba.conv1d.bias", "model.layers.10.mamba.norm.weight", "model.layers.10.mamba.out_proj.weight", "model.layers.12.mamba.dt_bias", "model.layers.12.mamba.A_log", "model.layers.12.mamba.D", "model.layers.12.mamba.in_proj.weight", "model.layers.12.mamba.conv1d.weight", "model.layers.12.mamba.conv1d.bias", "model.layers.12.mamba.norm.weight", "model.layers.12.mamba.out_proj.weight", "model.layers.14.mamba.dt_bias", "model.layers.14.mamba.A_log", "model.layers.14.mamba.D", "model.layers.14.mamba.in_proj.weight", "model.layers.14.mamba.conv1d.weight", "model.layers.14.mamba.conv1d.bias", "model.layers.14.mamba.norm.weight", "model.layers.14.mamba.out_proj.weight", "model.layers.16.mamba.dt_bias", "model.layers.16.mamba.A_log", "model.layers.16.mamba.D", "model.layers.16.mamba.in_proj.weight", "model.layers.16.mamba.conv1d.weight", "model.layers.16.mamba.conv1d.bias", "model.layers.16.mamba.norm.weight", "model.layers.16.mamba.out_proj.weight", "model.layers.18.mamba.dt_bias", "model.layers.18.mamba.A_log", "model.layers.18.mamba.D", "model.layers.18.mamba.in_proj.weight", "model.layers.18.mamba.conv1d.weight", "model.layers.18.mamba.conv1d.bias", "model.layers.18.mamba.norm.weight", "model.layers.18.mamba.out_proj.weight", "model.layers.20.mamba.dt_bias", "model.layers.20.mamba.A_log", "model.layers.20.mamba.D", "model.layers.20.mamba.in_proj.weight", "model.layers.20.mamba.conv1d.weight", "model.layers.20.mamba.conv1d.bias", "model.layers.20.mamba.norm.weight", "model.layers.20.mamba.out_proj.weight", "model.layers.22.mamba.dt_bias", "model.layers.22.mamba.A_log", "model.layers.22.mamba.D", "model.layers.22.mamba.in_proj.weight", "model.layers.22.mamba.conv1d.weight", "model.layers.22.mamba.conv1d.bias", "model.layers.22.mamba.norm.weight", "model.layers.22.mamba.out_proj.weight", 
"model.layers.24.mamba.dt_bias", "model.layers.24.mamba.A_log", "model.layers.24.mamba.D", "model.layers.24.mamba.in_proj.weight", "model.layers.24.mamba.conv1d.weight", "model.layers.24.mamba.conv1d.bias", "model.layers.24.mamba.norm.weight", "model.layers.24.mamba.out_proj.weight", "model.layers.26.mamba.dt_bias", "model.layers.26.mamba.A_log", "model.layers.26.mamba.D", "model.layers.26.mamba.in_proj.weight", "model.layers.26.mamba.conv1d.weight", "model.layers.26.mamba.conv1d.bias", "model.layers.26.mamba.norm.weight", "model.layers.26.mamba.out_proj.weight", "lm_head.weight".

jxiw commented 1 week ago

Could you please let me know what command produced this error?

XDjiang25 commented 1 week ago

Could you please let me know what command produced this error?

I only downloaded Mamba2InLlama_0_50 locally, and I got the error when loading the model with the command from the Generation Example:

pretrained_model_name = "/remote-home/nas/Mamba2InLlama_0_50"
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)

XDjiang25 commented 1 week ago

Could you please let me know what command produced this error?

I only downloaded Mamba2InLlama_0_50 locally, and I got the error when loading the model with the command from the Generation Example:

pretrained_model_name = "/remote-home/nas/Mamba2InLlama_0_50"
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)

Loading the checkpoint into the model then raises the error: there is no dt_bias or A_log, and since attn_layers in the Mamba config lists the odd layers (1, 3, 5, ...), the attn.k_proj.weight, attn.v_proj.weight, etc. of the even layers (0, 2, 4, ...) cannot be merged into in_proj.weight or out_proj.weight:

File "/remote-home/nas/MambaInLlama-main/mamba2_inference/hybrid_wrapper.py", line 107, in __init__
self.model.load_state_dict(ckpt)
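For anyone debugging the same mismatch, one quick diagnostic is to list the checkpoint keys and see which layer indices actually carry Mamba parameters, then compare that with attn_layers in the config. This is only a sketch under assumptions: the single pytorch_model.bin filename and the key layout are guesses based on this thread and common Hugging Face checkpoints, so adjust them to whatever your local download contains.

```python
import torch

ckpt_dir = "/remote-home/nas/Mamba2InLlama_0_50"  # local path used above

# Assumption: the weights sit in a single pytorch_model.bin; a sharded or
# safetensors download needs the corresponding loader instead.
state_dict = torch.load(f"{ckpt_dir}/pytorch_model.bin", map_location="cpu")

# Parameters named model.layers.<i>.mamba.* belong to the distilled Mamba layers;
# the remaining layer indices keep their attention projections.
layer_ids = {int(k.split(".")[2]) for k in state_dict if k.startswith("model.layers.")}
mamba_ids = {int(k.split(".")[2]) for k in state_dict if ".mamba." in k}
print("mamba layers:", sorted(mamba_ids))
print("attention layers:", sorted(layer_ids - mamba_ids))
```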

jxiw commented 1 week ago

This is a bit strange. The error suggests that you’re trying to load a hybrid mamba checkpoint into a LlamaForCausalLM model, which is incompatible. The error message says: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM.

I believe we may not be using LlamaForCausalLM to load this checkpoint.

XDjiang25 commented 1 week ago

This is a bit strange. The error suggests that you’re trying to load a hybrid mamba checkpoint into a LlamaForCausalLM model, which is incompatible. The error message says: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM.

I believe we may not be using LlamaForCausalLM to load this checkpoint.

Even with LlamaForSequenceClassification I get the same error.

jxiw commented 1 week ago

Thank you for this question. Just to be clear, we are not using LlamaForCausalLM, LlamaForSequenceClassification, or any Llama-based models at all. You should obtain a hybrid Mamba LM before loading this checkpoint. So just print self.model (or model) and check whether it is a hybrid Mamba model before you load the checkpoint. Thank you.
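For anyone hitting this, a minimal sketch of that check, following the Generation Example command quoted earlier in the thread (the import path follows the mamba2_inference/hybrid_wrapper.py location shown in the traceback above, so run it from the repo root or adjust it for your checkout):

```python
import torch
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "/remote-home/nas/Mamba2InLlama_0_50"  # local path from this thread

# Build the hybrid model through the wrapper, not LlamaForCausalLM, so the
# checkpoint's model.layers.*.mamba.* keys have matching modules to load into.
model = MambaTransformerHybridModelWrapper.from_pretrained(
    pretrained_model_name, torch_dtype=torch.bfloat16
)

# Print the module tree and confirm it is a hybrid Mamba model: the layers
# listed in attn_layers keep attention, the others should show Mamba blocks.
print(model)
```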

XDjiang25 commented 4 days ago

I really appreciate your reply! Your ideas are great.