google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Possible key error in ModelLoader #6

Closed · kaushiksiva07 closed 2 months ago

kaushiksiva07 commented 2 months ago

Description of the bug:

I am following along with "Example transformer models (decoder-only LLMs)" and am having issues with checkpoint mapping. Loading the checkpoint fails with this key error (truncated in the original log):

```
RuntimeError: Error(s) in loading state_dict for Phi2:
    Missing key(s) in state_dict:
        "transformer_blocks.0.atten_func.qkv_projection.weight", "transformer_blocks.0.atten_func.qkv_projection.bias", "transformer_blocks.0.atten_func.output_projection.weight", "transformer_blocks.0.atten_func.output_projection.bias",
        "transformer_blocks.1.atten_func.qkv_projection.weight", "transformer_blocks.1.atten_func.qkv_projection.bias", "transformer_blocks.1.atten_func.output_projection.weight", "transformer_blocks.1.atten_func.output_projection.bias",
        "transformer_blocks.2.atten_func.qkv_projection.weight", "transformer_blocks.2.atten_func.qkv_projection.bias", "transformer_blocks.2.atten_func.output_projection.weight", "transformer_blocks.2.atten_func.output_projection.bias",
        "transformer_blocks.3.atten_func.qkv_projection.weight", "transformer_blocks.3.atten_func.qkv_projection.bias", "transformer_blocks.3.atten_func.output_projection.weight", "transformer_blocks.3.atten_func.output_projection.bias",
        "transformer_blocks.4.atten_func.qkv_projection.weight", "transformer_blocks.4.atten_func.qkv_projection.bias", "transformer_blocks.4.atten_func.output_projection.weight", "transformer_blocks.4.atten_func.output_projection.bias",
        "transformer_blocks.5.atten_func.qkv_projection.weight", "transformer_blocks.5.atten_func.qkv_projection.bias", "transformer_blocks.5.atten_func.output_projection.weight", "transformer_blocks.5.atten_func.output_projection.bias",
        "transformer_blocks.6.atten_func.qkv_projection.weight", "transformer_blocks.6.atten_f...
    Unexpected key(s) in state_dict:
        "transformer_blocks.0.atten_func.attn.weight", "transformer_blocks.0.atten_func.attn.bias", "transformer_blocks.0.atten_func.proj.weight", "transformer_blocks.0.atten_func.proj.bias",
        "transformer_blocks.1.atten_func.attn.weight", "transformer_blocks.1.atten_func.attn.bias", "transformer_blocks.1.atten_func.proj.weight", "transformer_blocks.1.atten_func.proj.bias",
        "transformer_blocks.2.atten_func.attn.weight", "transformer_blocks.2.atten_func.attn.bias", "transformer_blocks.2.atten_func.proj.weight", "transformer_blocks.2.atten_func.proj.bias",
        "transformer_blocks.3.atten_func.attn.weight", "transformer_blocks.3.atten_func.attn.bias", "transformer_blocks.3.atten_func.proj.weight", "transformer_blocks.3.atten_func.proj.bias",
        "transformer_blocks.4.atten_func.attn.weight", "transformer_blocks.4.atten_func.attn.bias", "transformer_blocks.4.atten_func.proj.weight", "transformer_blocks.4.atten_func.proj.bias",
        "transformer_blocks.5.atten_func.attn.weight", "transformer_blocks.5.atten_func.attn.bias", "transformer_blocks.5.atten_func.proj.weight", "transformer_blocks.5.atten_func.proj.bias",
        "transformer_blocks.6.atten_func.attn.weight", "transformer_blocks.6.atten_func.attn.bias", "transformer_blocks.6.atten_func.proj.weight", "transformer_blocks.6.atten_func.proj.bias",
        "transformer_blocks.7.atten_func.attn.weight", "transformer_blocks.7.atten_func.attn.bias", "transformer_blocks.7.atten_func.proj.weight", "transformer_blocks.7.atten_func.pro...
```
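To confirm the mismatch, the model's expected keys can be diffed against the checkpoint's keys directly. A minimal sketch, assuming a model constructor and a local checkpoint path; `build_phi2_model` and `phi2_checkpoint.pt` are placeholders for whatever the example script actually uses:

```python
import torch

# Hypothetical names: build_phi2_model and the checkpoint path stand in
# for the Phi2 constructor and weights used in the example.
model = build_phi2_model()
checkpoint = torch.load("phi2_checkpoint.pt", map_location="cpu")

expected_keys = set(model.state_dict().keys())
checkpoint_keys = set(checkpoint.keys())

# Keys the model expects but the checkpoint lacks (qkv_projection/output_projection).
print(sorted(expected_keys - checkpoint_keys)[:4])
# Keys the checkpoint provides but the model does not define (attn/proj).
print(sorted(checkpoint_keys - expected_keys)[:4])
```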

Actual vs expected behavior:

The missing keys are the expected keys defined in `TransformerBlock(nn.Module)`, specifically in `CausalSelfAttention(nn.Module)`. However, it looks like the `_map_attention` function in `ModelLoader` maps the checkpoint keys to `transformer_blocks.{}.atten_func.attn` and `transformer_blocks.{}.atten_func.proj`. Should `attn` be replaced with `qkv_projection` and `proj` with `output_projection`? If so, the keys can be remapped as a workaround; see the sketch below.
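A sketch of that workaround, assuming the only mismatch is these two attention submodule names (not the library's fix; `checkpoint` is the loaded state dict from above):

```python
def remap_attention_keys(state_dict):
    """Rename checkpoint keys from the attn/proj names to the
    qkv_projection/output_projection names CausalSelfAttention expects."""
    remapped = {}
    for key, value in state_dict.items():
        key = key.replace(".atten_func.attn.", ".atten_func.qkv_projection.")
        key = key.replace(".atten_func.proj.", ".atten_func.output_projection.")
        remapped[key] = value
    return remapped

model.load_state_dict(remap_attention_keys(checkpoint))
```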

Any other information you'd like to share?

No response