madhurapande19 opened 10 months ago
Finetuning with multi-GPU might just be bugged, as reported here: https://github.com/Lightning-AI/lit-gpt/issues/652. Unfortunately, I haven't received any feedback on that issue.
Also, your training losses seem quite low. If you faced the same problem we did (training starting from random weights instead of the actual foundation model), I would expect the losses to still be high after 100 steps, unless you are training on a very small dataset that contains a prompt similar to the one you use to validate the model.
If you can afford it, you could try single-GPU training (making sure that it amounts to `fabric.init_module(empty_init=False)` in your version of the code) and see if that solves your problem...
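For context, here is a minimal sketch of the pattern in question, assuming a lit-gpt-style setup (the `GPT`, `Config`, and `load_checkpoint` names follow lit_gpt; exact signatures vary by version, so check your copy of finetune/lora.py):

```python
import lightning as L
from lit_gpt import GPT, Config
from lit_gpt.utils import load_checkpoint

fabric = L.Fabric(devices=1, precision="bf16-true")
fabric.launch()

config = Config.from_name("Llama-2-7b-chat-hf")

# empty_init=False materializes real tensors at construction time, so the
# checkpoint load below genuinely overwrites them. With empty_init=True and
# a faulty multi-GPU load, the model can silently keep random weights, and
# training then starts from scratch instead of from the foundation model.
with fabric.init_module(empty_init=False):
    model = GPT(config)

load_checkpoint(fabric, model, "../checkpoint/Llama-2-7b-chat-hf/lit_model.pth")
```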
@Jeronymous Thanks for your inputs. You are absolutely right. I tried fine-tuning on a bigger GPU, and single-GPU fine-tuning followed by inference gives much better results.
I am using LoRA to fine-tune Llama-2-7b-chat-hf on my custom dataset. This is the command I am running:
!CUDA_VISIBLE_DEVICES=2,3 python finetune/lora.py --data_dir '../new_test_data_llm' --checkpoint_dir '../checkpoint/Llama-2-7b-chat-hf/' --out_dir '../output_chat_model_temp'
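Note that `CUDA_VISIBLE_DEVICES=2,3` exposes two GPUs, so this launches the multi-GPU path discussed above. A hypothetical single-device variant of the same command would be:

!CUDA_VISIBLE_DEVICES=2 python finetune/lora.py --data_dir '../new_test_data_llm' --checkpoint_dir '../checkpoint/Llama-2-7b-chat-hf/' --out_dir '../output_chat_model_temp'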
Loss is decreasing steadily:
and the validation step that runs every 100 iterations also shows that the response is shaping up correctly (at least it generates the right response format in terms of number of words, etc.). This is the output from the inference that the fine-tuning script (finetune/lora.py) triggers after 100 iterations.
But later, when I run a separate inference call that explicitly loads the checkpoint trained above, with this command:
!python generate/lora.py --checkpoint_dir '../LLM/checkpoint/Llama-2-7b-chat-hf/' --lora_path '../output_chat_model_temp/lit_model_lora_finetuned.pth' --prompt 'Rephrase ... <same prompt> ... crisp response.' --input '0 year girls frock'
I get this weird output:
Why is this happening? Can someone please point out if I am doing something wrong here? I think the training itself is going fine (evident from the decreasing loss and the correct intermediate responses at validation steps), but the final inference output is so far off that I cannot understand why.
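In case it helps anyone debugging the same symptom, here is a rough sketch of what generate/lora.py does with the two checkpoints, assuming a lit-gpt version from around that time (names and the default LoRA hyperparameters r=8, alpha=16, dropout=0.05 are approximate; check your copy of the script):

```python
import torch
from lit_gpt.lora import GPT, Config, merge_lora_weights

# These LoRA hyperparameters must match the ones used in finetune/lora.py;
# in lit-gpt they were defined separately in both scripts, so a mismatch
# (different r/alpha, or different target layers) makes the adapter weights
# load against the wrong shapes or get applied with the wrong scaling,
# which can produce exactly this kind of degraded output.
config = Config.from_name("Llama-2-7b-chat-hf", r=8, alpha=16, dropout=0.05,
                          to_query=True, to_value=True)

model = GPT(config)
base = torch.load("../checkpoint/Llama-2-7b-chat-hf/lit_model.pth", map_location="cpu")
lora = torch.load("../output_chat_model_temp/lit_model_lora_finetuned.pth", map_location="cpu")
base.update(lora.get("model", lora))   # LoRA deltas overlay the base weights
model.load_state_dict(base)
merge_lora_weights(model)              # fold the LoRA matrices into the linear layers
model.eval()
```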