I am following along with "Example transformer models (decoder-only LLMs)" and am having issues with checkpoint mapping. I am getting this key error:
RuntimeError: Error(s) in loading state_dict for Phi2:
Missing key(s) in state_dict: "transformer_blocks.0.atten_func.qkv_projection.weight", "transformer_blocks.0.atten_func.qkv_projection.bias", "transformer_blocks.0.atten_func.output_projection.weight", "transformer_blocks.0.atten_func.output_projection.bias", "transformer_blocks.1.atten_func.qkv_projection.weight", "transformer_blocks.1.atten_func.qkv_projection.bias", "transformer_blocks.1.atten_func.output_projection.weight", "transformer_blocks.1.atten_func.output_projection.bias", "transformer_blocks.2.atten_func.qkv_projection.weight", "transformer_blocks.2.atten_func.qkv_projection.bias", "transformer_blocks.2.atten_func.output_projection.weight", "transformer_blocks.2.atten_func.output_projection.bias", "transformer_blocks.3.atten_func.qkv_projection.weight", "transformer_blocks.3.atten_func.qkv_projection.bias", "transformer_blocks.3.atten_func.output_projection.weight", "transformer_blocks.3.atten_func.output_projection.bias", "transformer_blocks.4.atten_func.qkv_projection.weight", "transformer_blocks.4.atten_func.qkv_projection.bias", "transformer_blocks.4.atten_func.output_projection.weight", "transformer_blocks.4.atten_func.output_projection.bias", "transformer_blocks.5.atten_func.qkv_projection.weight", "transformer_blocks.5.atten_func.qkv_projection.bias", "transformer_blocks.5.atten_func.output_projection.weight", "transformer_blocks.5.atten_func.output_projection.bias", "transformer_blocks.6.atten_func.qkv_projection.weight", "transformer_blocks.6.atten_f...
Unexpected key(s) in state_dict: "transformer_blocks.0.atten_func.attn.weight", "transformer_blocks.0.atten_func.attn.bias", "transformer_blocks.0.atten_func.proj.weight", "transformer_blocks.0.atten_func.proj.bias", "transformer_blocks.1.atten_func.attn.weight", "transformer_blocks.1.atten_func.attn.bias", "transformer_blocks.1.atten_func.proj.weight", "transformer_blocks.1.atten_func.proj.bias", "transformer_blocks.2.atten_func.attn.weight", "transformer_blocks.2.atten_func.attn.bias", "transformer_blocks.2.atten_func.proj.weight", "transformer_blocks.2.atten_func.proj.bias", "transformer_blocks.3.atten_func.attn.weight", "transformer_blocks.3.atten_func.attn.bias", "transformer_blocks.3.atten_func.proj.weight", "transformer_blocks.3.atten_func.proj.bias", "transformer_blocks.4.atten_func.attn.weight", "transformer_blocks.4.atten_func.attn.bias", "transformer_blocks.4.atten_func.proj.weight", "transformer_blocks.4.atten_func.proj.bias", "transformer_blocks.5.atten_func.attn.weight", "transformer_blocks.5.atten_func.attn.bias", "transformer_blocks.5.atten_func.proj.weight", "transformer_blocks.5.atten_func.proj.bias", "transformer_blocks.6.atten_func.attn.weight", "transformer_blocks.6.atten_func.attn.bias", "transformer_blocks.6.atten_func.proj.weight", "transformer_blocks.6.atten_func.proj.bias", "transformer_blocks.7.atten_func.attn.weight", "transformer_blocks.7.atten_func.attn.bias", "transformer_blocks.7.atten_func.proj.weight", "transformer_blocks.7.atten_func.pro...
Actual vs expected behavior:
The missing keys are the keys the model expects, defined within TransformerBlock(nn.Module), specifically in CausalSelfAttention(nn.Module). However, it looks like the _map_attention function in ModelLoader maps the checkpoint keys to transformer_blocks.{}.atten_func.attn and transformer_blocks.{}.atten_func.proj instead. Should attn be replaced with qkv_projection, and proj with output_projection?
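In the meantime, a possible workaround (the real fix presumably belongs in ModelLoader._map_attention) would be to rename the mismatched keys in the checkpoint's state_dict before calling load_state_dict. This is only a sketch based on the key names in the error above; remap_attention_keys is a hypothetical helper, not part of the library:

```python
# Hypothetical workaround: rename the mismatched attention keys in the
# checkpoint's state_dict so they match what CausalSelfAttention defines.
# The substrings below are taken from the error message in this report.
def remap_attention_keys(state_dict):
    renames = {
        ".atten_func.attn.": ".atten_func.qkv_projection.",
        ".atten_func.proj.": ".atten_func.output_projection.",
    }
    remapped = {}
    for key, value in state_dict.items():
        for old, new in renames.items():
            if old in key:
                key = key.replace(old, new)
        remapped[key] = value
    return remapped

# Example with dummy keys (values would normally be tensors):
ckpt = {
    "transformer_blocks.0.atten_func.attn.weight": 1,
    "transformer_blocks.0.atten_func.proj.bias": 2,
}
print(remap_attention_keys(ckpt))
```

The remapped dict could then be passed to model.load_state_dict() in place of the original checkpoint.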
Any other information you'd like to share?
No response