salesforce / jaxformer

Minimal library to train LLMs on TPU in JAX with pjit().
BSD 3-Clause "New" or "Revised" License

[Mismatched_sizes] Size mismatch exception when loading the fine-tuned model #20

Open Jacob-yen opened 1 year ago

Jacob-yen commented 1 year ago

Thank you for sharing the fine-tuning scripts for CodeGen. However, I encountered a problem when attempting to load the fine-tuned model with the following code, where pretrain_dir is the directory containing pytorch_model.bin and config.json.

import os
import transformers

# pretrain_dir: directory containing pytorch_model.bin and config.json
tokenizer = transformers.AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
model = transformers.CodeGenForCausalLM.from_pretrained(
    pretrain_dir, config=os.path.join(pretrain_dir, "config.json")
)

An exception was thrown:

Traceback (most recent call last):
  File "/home/User/code-models/get_retrained_model_distribution.py", line 54, in <module>
    tokenizer, retrained_model = model_utils.load_retrained_model(f"output/finetune_codegen/20230316-{args.ds_type}-Epoch30/final_checkpoint-{args.ds_type}-1",model_name)
  File "/home/User/code-models/utils/model_utils.py", line 68, in load_retrained_model
    model = transformers.CodeGenForCausalLM.from_pretrained(pretrain_dir,config=os.path.join(pretrain_dir,"config.json"), local_files_only=True)    
  File "/home/User/.conda/envs/transformers/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2379, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/User/.conda/envs/transformers/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2695, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for CodeGenForCausalLM:
        size mismatch for transformer.h.0.attn.qkv_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([3072, 1024]).
        size mismatch for transformer.h.0.attn.out_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
        size mismatch for transformer.h.0.mlp.fc_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096, 1024]).
        size mismatch for transformer.h.0.mlp.fc_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        [... identical size mismatches, each copying a param with shape torch.Size([0]) from the checkpoint, for the attn.qkv_proj, attn.out_proj, mlp.fc_in, and mlp.fc_out weights of transformer.h.1 through transformer.h.19 ...]
        size mismatch for lm_head.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([51200, 1024]).
        size mismatch for lm_head.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([51200]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

Adding ignore_mismatched_sizes=True avoids the exception but makes the model produce nonsense output. I am wondering how to properly load a model fine-tuned with the DeepSpeed scripts.
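
For what it's worth, the torch.Size([0]) shapes suggest the checkpoint may have been saved while the parameters were still partitioned under DeepSpeed ZeRO stage 3, in which case the pytorch_model.bin written on a single rank contains only empty placeholder tensors. If that is what happened, one way to recover usable weights is DeepSpeed's zero_to_fp32 utility, which reassembles the full fp32 state dict from the ZeRO shards. A minimal sketch, assuming ZeRO-3 was used and the checkpoint folder still contains the global_step*/ shard files (the checkpoint path below is hypothetical):

import transformers
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Hypothetical checkpoint directory written by the DeepSpeed fine-tuning run;
# it must still contain the global_step*/ folder with the ZeRO shard files.
checkpoint_dir = "output/finetune_codegen/final_checkpoint"

# Reassemble the full fp32 state dict from the partitioned ZeRO shards.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)

# Build the architecture from the base model, then overwrite its weights
# with the consolidated fine-tuned state dict.
model = transformers.CodeGenForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-multi", state_dict=state_dict
)

DeepSpeed also writes a standalone zero_to_fp32.py script into the checkpoint directory, which performs the same conversion from the command line.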

Thanks in advance. :)