Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

Failed to load the finetuned model with `AutoModelForCausalLM.from_pretrained(name, state_dict=state_dict)` #1362

Open zhaosheng-thu opened 3 weeks ago

zhaosheng-thu commented 3 weeks ago

I fine-tuned Llama 3 8B with LoRA and followed the tutorial in the repository to convert the final result into model.pth. However, when I try to load the fine-tuned weights into the model using AutoModelForCausalLM.from_pretrained, they are not loaded correctly. Below is my test:

import torch
from transformers import AutoModelForCausalLM

# load the converted litgpt checkpoint
state_dict = torch.load('out/convert/hf-llama3-instruct-esconv/model.pth')
print("state_dict: ", state_dict)

# try to load the fine-tuned weights into the HF model
model = AutoModelForCausalLM.from_pretrained('checkpoints/meta-llama/Meta-Llama-3-8B',
                                             device_map=device_map,  # e.g. "auto"
                                             torch_dtype=torch.float16,
                                             state_dict=state_dict)

print("model.weights", model.state_dict())

But I found that the state_dict returned by torch.load does not match model.state_dict(), as shown in the screenshots below:

[screenshot: output of torch.load]

[screenshot: output of model.state_dict()]
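For a more direct check than eyeballing the two printouts, one can compare a single tensor from both state dicts, along the lines of the sketch below (the key name is only an example and may differ for your checkpoint):

key = "model.layers.0.self_attn.q_proj.weight"  # example key; any key present in both dicts works
same = torch.equal(state_dict[key].float().cpu(),
                   model.state_dict()[key].float().cpu())
print(f"{key} matches loaded model: {same}")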

I noticed that even though I passed the state_dict, from_pretrained still returns the model with the original weights loaded from name. Did I make a mistake in my code, and how can I solve this? Thanks!

zhaosheng-thu commented 3 weeks ago

I can load the weights using model.load_state_dict(), and then everything goes smoothly, but I really want to know why from_pretrained(state_dict=state_dict) doesn't work.
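For context, the workaround looks roughly like this (a minimal sketch; the paths are the same ones as in my snippet above):

import torch
from transformers import AutoModelForCausalLM

# load the base model first, then overwrite its weights with the converted checkpoint
model = AutoModelForCausalLM.from_pretrained('checkpoints/meta-llama/Meta-Llama-3-8B',
                                             torch_dtype=torch.float16)
state_dict = torch.load('out/convert/hf-llama3-instruct-esconv/model.pth')
model.load_state_dict(state_dict)  # the fine-tuned weights are applied as expected here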

rasbt commented 2 weeks ago

Thanks for raising that. Maybe it's a HF thing. I will have to investigate.

rasbt commented 2 weeks ago

I gave it a quick try with another model and could not reproduce it yet.

I am not sure if it's related because the differences are so big, but I wonder ~what the precision of the tensors in your current state dict are. Could you print the precision of the state dict, and~ could you also try to load it without torch_dtype=torch.float16?

EDIT: Nevermind, I can see that the precision is bfloat16 in your screenshot.

[Screenshot 2024-04-29 at 12:18:03 PM]

rasbt commented 2 weeks ago

I also tried this with Llama 3, and it seemed to work fine for me there as well. Here are my steps:

litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct --access_token ...

litgpt finetune \
    --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B-Instruct \
    --out_dir my_llama_model \
    --train.max_steps 1 \
    --eval.max_iter 1

litgpt convert from_litgpt \
    --checkpoint_dir my_llama_model/final \
    --output_dir out/converted_llama_model/

And then in a Python session:

[Screenshot 2024-04-29 at 1:17:29 PM]

and

[Screenshot 2024-04-29 at 1:21:32 PM]
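
For readers who can't see the screenshots, the Python session presumably looks roughly like the sketch below (paths follow the commands above; the exact checks shown in the screenshots may differ, and the example key is an assumption):

import torch
from transformers import AutoModelForCausalLM

# load the checkpoint produced by `litgpt convert from_litgpt` above
state_dict = torch.load("out/converted_llama_model/model.pth")

# pass it to from_pretrained and compare one tensor against the raw state dict
model = AutoModelForCausalLM.from_pretrained(
    "checkpoints/meta-llama/Meta-Llama-3-8B-Instruct", state_dict=state_dict
)
key = "model.embed_tokens.weight"  # example key; any shared key works
print(torch.equal(model.state_dict()[key].float().cpu(),
                  state_dict[key].float().cpu()))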