pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Can I Finetune Llama3 Without Creating a Custom Dataset Function? #1362

Open JinchuLi2002 opened 1 month ago

JinchuLi2002 commented 1 month ago

Hello,

I have a dataset in .jsonl format where each line looks like this:

{"messages": [{"role": "system", "content": "some system msg"}, {"role": "user", "content": "some user input"}, {"role": "assistant", "content": "some expected output"}]}

I see there is an option to specify conversation_style=openai and source=json in the config files. However, the Llama3 fine-tuning tutorial does not mention this.

I was able to get fine-tuning running, but the results weren't very good, and I was wondering whether my workflow is correct at all. (Also, I couldn't seem to pass a system prompt during inference testing; any ideas?)

Thank you in advance

RdoubleA commented 1 month ago

Hi @JinchuLi2002, what you described seems correct. If you don't mind sharing how you defined the dataset in your config, I can provide more guidance.

We've just added support for setting a system prompt in #1366, and we'll update the inference recipe soon as well.
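
Once that lands, it should just be an extra field on the dataset in your config, roughly along these lines (a sketch only; check #1366 for the exact field name):

dataset:
  _component_: torchtune.datasets.chat_dataset
  source: json
  conversation_style: openai
  # assumed name of the new option from #1366 - verify against the merged PR
  new_system_prompt: "some system msg"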

JinchuLi2002 commented 1 month ago

@RdoubleA Thank you for the reply, below is how I defined my dataset.

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: json
  train_on_input: True
  conversation_style: openai
  chat_format: torchtune.data.ChatMLFormat
  max_seq_len: 8192
  data_files: /global/Jinchu/datasets/test.jsonl
  split: train
seed: null
shuffle: True
batch_size: 4

For now, the dataset is essentially a single row repeated 1000 times; I was hoping that after fine-tuning the model would reproduce the same answer:

{"messages": [{"role": "user", "content": "what day is today"}, {"role": "assistant", "content": "aug 11 2002, Monday"}]}

Here I removed the system prompt, since you mentioned system prompts are not supported in inference yet.
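
For reference, a file like that can be generated with a few lines of Python:

import json

# the single example that gets repeated to form the toy dataset
row = {
    "messages": [
        {"role": "user", "content": "what day is today"},
        {"role": "assistant", "content": "aug 11 2002, Monday"},
    ]
}

# write 1000 copies, one JSON object per line (.jsonl)
with open("/global/Jinchu/datasets/test.jsonl", "w") as f:
    for _ in range(1000):
        f.write(json.dumps(row) + "\n")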

However, when I ran the generation recipe with the following config:

model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /global/Jinchu/models/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [
    meta_model_9.pt,
  ]
  output_dir: /global/Jinchu/models/Meta-Llama-3-8B-Instruct
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 42

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /global/Jinchu/models/Meta-Llama-3-8B-Instruct/original/tokenizer.model

prompt: what day is today?

chat_format: torchtune.data.ChatMLFormat
instruct_template: null
max_new_tokens: 30
temperature: 0 
top_k: 300
enable_kv_cache: True

quantizer: null

The result I got was

...
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:what day is today? - 2023-03-24
what day is today? - 2023-03-25
what day is today? - 
INFO:torchtune.utils.logging:Time for inference: 1.84 sec total, 16.28 tokens/sec
...

felipemello1 commented 1 month ago

@JinchuLi2002, nice experiment! How did the loss look during training? Did the model overfit to the dataset? Maybe the LR needs to be higher.

RdoubleA commented 1 month ago

Your dataset looks correctly configured. It may be difficult to get the base model to overfit to a specific response and undo its pretrained behaviors, but you could try different experiments as @felipemello1 suggested. The only other thing I would recommend is using a system prompt, which we'll need to add to the generate recipe.

Are you simply trying to test whether the model is learning from your data? If so, I would recommend looking at the training loss curve and proper eval metrics to assess this, rather than relying on a single generation.
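
For the loss curve, you can point the metric logger in your fine-tuning config at W&B (or disk) and log every step. A rough sketch, assuming your torchtune version still exposes the loggers under torchtune.utils.metric_logging:

# in the fine-tuning config; the component path may differ across torchtune versions
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: llama3-overfit-test   # example project name
log_every_n_steps: 1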

JinchuLi2002 commented 1 month ago

@RdoubleA @felipemello1 Below is the loss curve (10 epochs).

[image: training loss curve over 10 epochs]

I'll keep experimenting with parameters like the LR (here I used 3e-5) as you all suggested, thanks!
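
For the LR I'll just bump it in the optimizer section of the fine-tuning config, something like this (assuming the stock torch.optim.AdamW optimizer):

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-4   # raised from 3e-5; exact value still to be tuned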

felipemello1 commented 1 month ago

That loss makes sense to me. And I also see that you have temperature = 0.

In your eval, does the model get it wrong most of the time, or just occasionally? What I find odd is that this type of problem shouldn't even require fine-tuning; having the example in the context should be enough.
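
For instance, with the Instruct model and no fine-tuning at all, you could put the example straight into the prompt of the generation config, something like:

prompt: |
  what day is today - aug 11 2002, Monday
  what day is today -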

Alternative: