TinyLLaVA / TinyLLaVA_Factory

A Framework of Small-scale Large Multimodal Models
https://arxiv.org/abs/2402.14289
Apache License 2.0

Missing key "lm_head.weight" in GemmaForCausalLM when loading lora finetuned TinyLLaVA-Gemma-SigLIP-2.4B #88

Open Yuki-Kokomi opened 1 week ago

Yuki-Kokomi commented 1 week ago

When attempting to merge LoRA weights into the TinyLLaVA-Gemma-SigLIP-2.4B model, I encountered a RuntimeError due to a missing key lm_head.weight in the GemmaForCausalLM state_dict. The specific error traceback is as follows:

Traceback (most recent call last):
  File "/../TinyLLaVA_Factory/merge_lora_weights.py", line 27, in <module>
    merge_lora(args)
  File "/../TinyLLaVA_Factory/merge_lora_weights.py", line 14, in merge_lora
    model, tokenizer,  image_processor, context_len = load_pretrained_model(args.model_path)
  File "/../TinyLLaVA_Factory/tinyllava/model/load_model.py", line 45, in load_pretrained_model
    model.language_model.load_state_dict(language_model_ckp)
  File "/../miniforge3/envs/tinyllava_factory/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GemmaForCausalLM:
        Missing key(s) in state_dict: "lm_head.weight".

This issue seems similar to vllm-project/vllm#3323. Any insights or solutions to resolve the missing "lm_head.weight" key would be greatly appreciated.
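
Possibly relevant context: Gemma ties its output head to its input embeddings (tie_word_embeddings=True in the Hugging Face config), so a checkpoint serialized from GemmaForCausalLM can legitimately omit lm_head.weight. A quick way to confirm this against the saved file, as a sketch following the path layout from the traceback above:

    import torch

    # Path layout taken from the traceback; adjust to your model directory.
    ckp_path = '<model_path>/language_model/pytorch_model.bin'
    state_dict = torch.load(ckp_path, map_location='cpu')

    # With tied embeddings, the output head is typically not serialized.
    print('lm_head.weight' in state_dict)             # expected: False
    print('model.embed_tokens.weight' in state_dict)  # expected: True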

ggcr commented 1 week ago

This just happened to me as well. After the pre-training phase, I am trying to run inference, but there seems to be a mismatch with the model saved at /<model_path>/language_model/pytorch_model.bin.

ggcr commented 1 week ago

A quick workaround for now is to follow the solution from https://github.com/vllm-project/vllm/issues/3323 (linked by @Yuki-Kokomi) and copy the embed_tokens weight onto the lm_head:

        language_model_ckp_path = os.path.join(model_name_or_path, 'language_model/pytorch_model.bin')
        language_model_ckp = load_base_ckp_for_lora(language_model_ckp_path)

        # This line is what does the trick
        language_model_ckp['lm_head.weight'] = language_model_ckp['model.embed_tokens.weight']

        model.language_model.load_state_dict(language_model_ckp)

However, I have not been able to qualitatively validate that this fix works correctly.
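
If the Gemma backbone really does tie its word embeddings (the Hugging Face GemmaConfig sets tie_word_embeddings=True), the copy above should be exact rather than approximate: it reproduces what tie_weights() would do at load time anyway. A quick way to check the flag, as a sketch (google/gemma-2b is assumed here to be the underlying LLM):

    from transformers import AutoConfig

    # If this prints True, lm_head.weight is defined to be identical to
    # model.embed_tokens.weight, so the manual copy is lossless.
    config = AutoConfig.from_pretrained('google/gemma-2b')
    print(config.tie_word_embeddings)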

Yuki-Kokomi commented 1 week ago

Thank you @ggcr for the solution. It works perfectly.

ggcr commented 1 week ago

If the authors want, I can open a PR that applies this fix after checking whether the LLM backbone is Gemma.
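
A rough sketch of what that guard could look like inside load_pretrained_model (the surrounding variables come from the existing loading code; the conditional itself is only a proposal):

    # Proposed guard: apply the embed_tokens -> lm_head copy only when the
    # backbone is Gemma and the key is actually missing.
    config = model.language_model.config
    if getattr(config, 'model_type', None) == 'gemma' and 'lm_head.weight' not in language_model_ckp:
        language_model_ckp['lm_head.weight'] = language_model_ckp['model.embed_tokens.weight']
    model.language_model.load_state_dict(language_model_ckp)

Keying off config.tie_word_embeddings instead of the model type would extend the same fix to any future backbone with tied embeddings.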

YingHuTsing commented 1 week ago

> If the authors want, I can open a PR that applies this fix after checking whether the LLM backbone is Gemma.

Hi, we do encourage you to initiate a PR!