Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Significantly different results when running inference from a saved checkpoint vs. inference during fine-tuning #686

Open madhurapande19 opened 10 months ago

madhurapande19 commented 10 months ago

I am using LoRA to fine-tune Llama-2-7b-chat-hf on my custom dataset. This is the command I am running:

!CUDA_VISIBLE_DEVICES=2,3 python finetune/lora.py --data_dir '../new_test_data_llm' --checkpoint_dir '../checkpoint/Llama-2-7b-chat-hf/' --out_dir '../output_chat_model_temp'

Loss is decreasing steadily:

iter 91 step 92: loss 0.3682, iter time: 2216.19ms (optimizer.step)
iter 92 step 93: loss 0.5185, iter time: 2389.64ms (optimizer.step)
iter 93 step 94: loss 0.3875, iter time: 2417.41ms (optimizer.step)
iter 94 step 95: loss 0.1968, iter time: 2399.70ms (optimizer.step)
iter 95 step 96: loss 0.2645, iter time: 2209.37ms (optimizer.step)
iter 96 step 97: loss 0.3107, iter time: 2206.34ms (optimizer.step)
iter 97 step 98: loss 0.3512, iter time: 2174.37ms (optimizer.step)
iter 98 step 99: loss 0.3052, iter time: 2338.01ms (optimizer.step)
iter 99 step 100: loss 0.3454, iter time: 2274.14ms (optimizer.step)

The validation step that runs after 100 iterations also shows that the response is shaping up correctly (at least it generates the right response format, e.g. the number of words). This is the output of the validation generation that the fine-tuning script (finetune/lora.py) triggers after 100 iterations:

iter 98 step 99: loss 0.3051, iter time: 2506.50ms (optimizer.step)
iter 99 step 100: loss 0.3453, iter time: 2524.91ms (optimizer.step)
Validating ...
Rephrase the ... <full prompt> ... request.

### Instruction:
Rephrase ... <full prompt> ... crisp response.

### Input:
0 year girls frock

### Response:8 year girl clothess

But when I later run a separate inference call that explicitly loads the checkpoint trained above, using this command:

!python generate/lora.py --checkpoint_dir '../LLM/checkpoint/Llama-2-7b-chat-hf/' --lora_path '../output_chat_model_temp/lit_model_lora_finetuned.pth' --prompt 'Rephrase ... <same prompt> ... crisp response.' --input '0 year girls frock'

I get this weird output:

Time to instantiate model: 3.37 seconds.
Time to load the model weights: 8.97 seconds.
Number of total parameters: 6,742,609,920
0 years girls females andmature 3 years girls baby carriage andmature 5 years girls females andmature 7 years girls females and 9 years girls females andmature 11 years girls females, etc.
0 years girls females and 0 years girls females 0 years girls females, etc.
0 years girls females, etc.

### Instruction:
Repeat the story2018 girls females, etc.
0 years girls

Time for inference: 4.89 sec total, 20.46 tokens/sec
Memory used: 13.82 GB

Why is this happening? Can someone please point out whether I am doing something wrong here? I think the training is going fine (judging by the decreasing loss and the correct intermediate responses in the validation steps), but the final inference output is far off, and I can't understand why.
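For context on what the separate inference call has to get right: with LoRA, the file passed via --lora_path is expected to hold only the small adapter matrices (as is typical for LoRA checkpoints), and generate/lora.py combines them with the base weights loaded from --checkpoint_dir. The sketch below is a conceptual, plain-Python illustration of the standard LoRA update W_eff = W + (alpha / r) * B @ A; it is not litgpt's code, and the `matmul`/`merge_lora` helpers and toy matrices are hypothetical. Whether an implementation merges the matrices up front or adds the low-rank term on every forward pass, the effective weight is the same.

```python
# Conceptual sketch of applying a LoRA adapter at inference time (not litgpt's code).
# Assumes the standard LoRA update W_eff = W + (alpha / r) * B @ A, where W is the
# frozen pretrained weight (d_out x d_in), A is r x d_in and B is d_out x r.
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain-Python matrix product; fine for toy sizes."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(w: Matrix, a: Matrix, b: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * B @ A, the effective weight used for generation."""
    scale = alpha / r
    delta = matmul(b, a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy 2x2 layer with a rank-1 adapter; all numbers are made up for illustration.
w_pretrained = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[0.5, -0.5]]   # r x d_in  = 1 x 2
lora_b = [[2.0], [0.0]]  # d_out x r = 2 x 1
merged = merge_lora(w_pretrained, lora_a, lora_b, alpha=2.0, r=1)
print(merged)  # [[3.0, -2.0], [0.0, 1.0]]

# The adapter only stores a delta, so the result depends on which base it is added to.
w_other = [[0.3, -1.2], [0.7, 0.1]]  # stand-in for weights the adapter was not trained on
print(merge_lora(w_other, lora_a, lora_b, alpha=2.0, r=1))
```

Because the adapter is only a delta, adding the same adapter to a different base (the second call) yields a different effective weight. So if training ran against base weights other than the pretrained checkpoint, applying the adapter to the real checkpoint at inference time can produce output unlike anything seen during validation.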

Jeronymous commented 10 months ago

Fine-tuning with multiple GPUs might simply be buggy, as reported here: https://github.com/Lightning-AI/lit-gpt/issues/652. Unfortunately, I haven't gotten any feedback on that issue.

Also, your training losses seem quite low. If you hit the same problem we did (training starting from random weights instead of the actual foundation model weights), I would expect the loss to still be high after 100 steps, unless you are training on a very small dataset that contains prompts similar to the one you use to validate the model.
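To put these numbers in perspective, assuming the logged loss is the mean per-token cross-entropy in nats (natural log, as in PyTorch's cross_entropy): a model that spreads its probability uniformly over a vocabulary of V tokens scores ln(V). For Llama 2's 32,000-token vocabulary that is about 10.37, roughly where a randomly initialized model starts (only approximately, since random logits are not exactly uniform). A loss of 0.3454 instead means the model assigns a geometric-mean probability of about 0.71 to the correct next token. A quick check:

```python
import math

# Assumes the logged loss is the mean per-token cross-entropy in nats
# (natural log, as computed by PyTorch's cross_entropy).
vocab_size = 32_000  # Llama 2 tokenizer vocabulary size

# A uniform prediction over V tokens gives loss ln(V); a freshly randomly
# initialized model starts roughly here.
uniform_loss = math.log(vocab_size)

# exp(-loss) is the geometric-mean probability assigned to the correct next token.
observed_loss = 0.3454  # "iter 99 step 100" in the training log
p_correct = math.exp(-observed_loss)

print(f"uniform-guess loss: {uniform_loss:.2f} nats")
print(f"loss {observed_loss} -> geometric-mean p(correct token) = {p_correct:.2f}")
```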

If you can afford it, try single-GPU training (making sure that it ends up calling fabric.init_module(empty_init=False) in your version of the code) and see whether that solves your problem.
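One way to test this hypothesis directly is to check, before the first optimizer step, whether the frozen (non-LoRA) weights the model actually holds match the pretrained checkpoint. If the model were left with uninitialized or random weights (for example, an empty init whose checkpoint load never takes effect), they would not. Below is a minimal dependency-free sketch. It assumes hypothetical state dicts flattened to {name: list of floats} and LoRA parameters whose names contain "lora_"; both are assumptions for illustration. With real tensors you would instead compare each entry with torch.equal or torch.allclose on the unsharded model.

```python
# Sketch: find frozen (non-LoRA) weights that differ from the pretrained checkpoint.
# Assumptions (hypothetical): state dicts are plain {name: list_of_floats} mappings,
# and LoRA adapter parameters have "lora_" in their names.
from typing import Dict, List

StateDict = Dict[str, List[float]]


def mismatched_base_weights(model_sd: StateDict, checkpoint_sd: StateDict,
                            atol: float = 0.0) -> List[str]:
    """Return names of non-LoRA model parameters that don't match the checkpoint."""
    bad = []
    for name, values in model_sd.items():
        if "lora_" in name:
            continue  # adapters are meant to differ; only frozen weights are checked
        ckpt_values = checkpoint_sd.get(name)
        if (
            ckpt_values is None
            or len(ckpt_values) != len(values)
            or any(abs(v - c) > atol for v, c in zip(values, ckpt_values))
        ):
            bad.append(name)
    return bad


# Toy example: one model loaded correctly, one left with random base weights.
checkpoint = {"attn.weight": [0.1, -0.2, 0.3], "mlp.weight": [1.0, 2.0]}
good_model = {**checkpoint, "attn.lora_A": [0.0, 0.0], "attn.lora_B": [0.5, 0.5]}
random_model = {"attn.weight": [0.7, 0.05, -0.9], "mlp.weight": [1.0, 2.0],
                "attn.lora_A": [0.0, 0.0]}
print(mismatched_base_weights(good_model, checkpoint))    # []
print(mismatched_base_weights(random_model, checkpoint))  # ['attn.weight']
```

An empty result for the real run would suggest training did start from the pretrained weights; any mismatch on the multi-GPU run but not on a single GPU would point at the initialization/loading path.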

madhurapande19 commented 10 months ago

@Jeronymous Thanks for your input. You are absolutely right: I tried fine-tuning on a single, bigger GPU, and inference after single-GPU fine-tuning is considerably better.