XDjiang25 opened this issue 1 week ago
Could you please let me know what command you use here? I only downloaded Mamba2InLlama_0_50 locally, and I get an error when loading the model with the command from the Generation Example:

```python
pretrained_model_name = "/remote-home/nas/Mamba2InLlama_0_50"
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
```

When the checkpoint is loaded into the model it reports an error: there is no dt_bias or A_log, and since attn_layers in Mamba_config lists the odd layers (1, 3, 5, ...), the attn.k_proj.weight, attn.v_proj.weight, etc. of the even layers (0, 2, 4, ...) cannot be merged into in_proj.weight or out_proj.weight:

```
File "/remote-home/nas/MambaInLlama-main/mamba2_inference/hybrid_wrapper.py", line 107, in __init__
    self.model.load_state_dict(ckpt)
```
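As a diagnostic aside, one way to see what the downloaded checkpoint actually contains is to list its state-dict keys directly, before any model is constructed. This is only a sketch; the weight file name depends on how the checkpoint was exported (a single `model.safetensors`, sharded safetensors, or `pytorch_model.bin`), so adjust the path as needed.

```python
import os
import torch

ckpt_dir = "/remote-home/nas/Mamba2InLlama_0_50"  # local download from the example

# The checkpoint may be stored either as safetensors or as a PyTorch .bin file.
st_path = os.path.join(ckpt_dir, "model.safetensors")
bin_path = os.path.join(ckpt_dir, "pytorch_model.bin")

if os.path.exists(st_path):
    from safetensors.torch import load_file
    state_dict = load_file(st_path)
else:
    state_dict = torch.load(bin_path, map_location="cpu")

# Print a few keys; a hybrid checkpoint should contain mamba.* parameters
# (dt_bias, A_log, D, in_proj, conv1d, norm, out_proj) for the non-attention layers.
for k in list(state_dict.keys())[:16]:
    print(k)
```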
This is a bit strange. The error suggests that you're trying to load a hybrid Mamba checkpoint into a `LlamaForCausalLM` model, which is incompatible. The error message says: `RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM`. I don't believe we use `LlamaForCausalLM` to load this checkpoint.
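To see why the two are incompatible, here is a small self-contained illustration (not the repository's code): a plain `LlamaForCausalLM` built from a tiny config has attention/MLP parameters in every layer and no `mamba` parameters at all, so its state dict can never line up with a hybrid Mamba/attention checkpoint.

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny config purely for illustration; the real checkpoint is much larger.
cfg = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=128,
)
tiny = LlamaForCausalLM(cfg)

state = tiny.state_dict()
print(any("mamba" in k for k in state))                          # False: no Mamba params
print([k for k in state if k.startswith("model.layers.0.")][:4])  # self_attn / mlp keys only
```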
Even `LlamaForSequenceClassification` gives the same error.
Thank you for this question. Just to be clear, we are not using `LlamaForCausalLM`, `LlamaForSequenceClassification`, or any Llama-based models at all. You should have a hybrid Mamba LM before loading this checkpoint. So just print `self.model` (or `model`) and check that it is a hybrid Mamba model before you load the checkpoint. Thank you.
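As a concrete version of that check, the sketch below adds a couple of prints in `hybrid_wrapper.py` just before the failing `load_state_dict` call. The attribute layout (`self.model.model.layers`, a `.mamba` submodule on non-attention layers) is inferred from the key names in the traceback, so treat it as an assumption rather than the repository's exact structure.

```python
# mamba2_inference/hybrid_wrapper.py (sketch): just before the failing call
print(type(self.model))  # which class the checkpoint is being loaded into

# Key names like "model.layers.0.mamba.dt_bias" suggest the layers live at
# self.model.model.layers; adjust if the wrapper nests them differently.
for i, layer in enumerate(self.model.model.layers):
    kind = "mamba" if hasattr(layer, "mamba") else "attention/other"
    print(f"layer {i:2d}: {kind}")

self.model.load_state_dict(ckpt)  # original line 107
```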
I really appreciate your reply! Your idea is great, but I seem to be having trouble reproducing it:

```
RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
    Missing key(s) in state_dict:
    "model.layers.0.mamba.dt_bias", "model.layers.0.mamba.A_log", "model.layers.0.mamba.D", "model.layers.0.mamba.in_proj.weight", "model.layers.0.mamba.conv1d.weight", "model.layers.0.mamba.conv1d.bias", "model.layers.0.mamba.norm.weight", "model.layers.0.mamba.out_proj.weight",
    "model.layers.2.mamba.dt_bias", "model.layers.2.mamba.A_log", "model.layers.2.mamba.D", "model.layers.2.mamba.in_proj.weight", "model.layers.2.mamba.conv1d.weight", "model.layers.2.mamba.conv1d.bias", "model.layers.2.mamba.norm.weight", "model.layers.2.mamba.out_proj.weight",
    "model.layers.4.mamba.dt_bias", "model.layers.4.mamba.A_log", "model.layers.4.mamba.D", "model.layers.4.mamba.in_proj.weight", "model.layers.4.mamba.conv1d.weight", "model.layers.4.mamba.conv1d.bias", "model.layers.4.mamba.norm.weight", "model.layers.4.mamba.out_proj.weight",
    "model.layers.6.mamba.dt_bias", "model.layers.6.mamba.A_log", "model.layers.6.mamba.D", "model.layers.6.mamba.in_proj.weight", "model.layers.6.mamba.conv1d.weight", "model.layers.6.mamba.conv1d.bias", "model.layers.6.mamba.norm.weight", "model.layers.6.mamba.out_proj.weight",
    "model.layers.8.mamba.dt_bias", "model.layers.8.mamba.A_log", "model.layers.8.mamba.D", "model.layers.8.mamba.in_proj.weight", "model.layers.8.mamba.conv1d.weight", "model.layers.8.mamba.conv1d.bias", "model.layers.8.mamba.norm.weight", "model.layers.8.mamba.out_proj.weight",
    "model.layers.10.mamba.dt_bias", "model.layers.10.mamba.A_log", "model.layers.10.mamba.D", "model.layers.10.mamba.in_proj.weight", "model.layers.10.mamba.conv1d.weight", "model.layers.10.mamba.conv1d.bias", "model.layers.10.mamba.norm.weight", "model.layers.10.mamba.out_proj.weight",
    "model.layers.12.mamba.dt_bias", "model.layers.12.mamba.A_log", "model.layers.12.mamba.D", "model.layers.12.mamba.in_proj.weight", "model.layers.12.mamba.conv1d.weight", "model.layers.12.mamba.conv1d.bias", "model.layers.12.mamba.norm.weight", "model.layers.12.mamba.out_proj.weight",
    "model.layers.14.mamba.dt_bias", "model.layers.14.mamba.A_log", "model.layers.14.mamba.D", "model.layers.14.mamba.in_proj.weight", "model.layers.14.mamba.conv1d.weight", "model.layers.14.mamba.conv1d.bias", "model.layers.14.mamba.norm.weight", "model.layers.14.mamba.out_proj.weight",
    "model.layers.16.mamba.dt_bias", "model.layers.16.mamba.A_log", "model.layers.16.mamba.D", "model.layers.16.mamba.in_proj.weight", "model.layers.16.mamba.conv1d.weight", "model.layers.16.mamba.conv1d.bias", "model.layers.16.mamba.norm.weight", "model.layers.16.mamba.out_proj.weight",
    "model.layers.18.mamba.dt_bias", "model.layers.18.mamba.A_log", "model.layers.18.mamba.D", "model.layers.18.mamba.in_proj.weight", "model.layers.18.mamba.conv1d.weight", "model.layers.18.mamba.conv1d.bias", "model.layers.18.mamba.norm.weight", "model.layers.18.mamba.out_proj.weight",
    "model.layers.20.mamba.dt_bias", "model.layers.20.mamba.A_log", "model.layers.20.mamba.D", "model.layers.20.mamba.in_proj.weight", "model.layers.20.mamba.conv1d.weight", "model.layers.20.mamba.conv1d.bias", "model.layers.20.mamba.norm.weight", "model.layers.20.mamba.out_proj.weight",
    "model.layers.22.mamba.dt_bias", "model.layers.22.mamba.A_log", "model.layers.22.mamba.D", "model.layers.22.mamba.in_proj.weight", "model.layers.22.mamba.conv1d.weight", "model.layers.22.mamba.conv1d.bias", "model.layers.22.mamba.norm.weight", "model.layers.22.mamba.out_proj.weight",
    "model.layers.24.mamba.dt_bias", "model.layers.24.mamba.A_log", "model.layers.24.mamba.D", "model.layers.24.mamba.in_proj.weight", "model.layers.24.mamba.conv1d.weight", "model.layers.24.mamba.conv1d.bias", "model.layers.24.mamba.norm.weight", "model.layers.24.mamba.out_proj.weight",
    "model.layers.26.mamba.dt_bias", "model.layers.26.mamba.A_log", "model.layers.26.mamba.D", "model.layers.26.mamba.in_proj.weight", "model.layers.26.mamba.conv1d.weight", "model.layers.26.mamba.conv1d.bias", "model.layers.26.mamba.norm.weight", "model.layers.26.mamba.out_proj.weight",
    "lm_head.weight".
```