pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Loss not going down for fine-tuning Llama3-8B on C4 #1526

Open andrewor14 opened 2 months ago

andrewor14 commented 2 months ago

I'm fine-tuning Llama3-8B on the C4 dataset (en subset) for 2000 steps using the full_finetune_distributed recipe. The loss did not go down at all, and the accuracy of the quantized fine-tuned checkpoint is very low. The exact same workflow used to work in late June (I can't find the exact commit now) but seems to be broken in recent commits (8/13/24).

Eval of the quantized fine-tuned checkpoint:

|  Tasks  |Version|Filter|n-shot|    Metric     |Value |   |Stderr|
|---------|------:|------|-----:|---------------|-----:|---|------|
|wikitext |      2|none  |     0|word_perplexity|   NaN|±  |N/A   |
|         |       |none  |     0|byte_perplexity|   NaN|±  |N/A   |
|         |       |none  |     0|bits_per_byte  |   NaN|±  |N/A   |
|hellaswag|      1|none  |     0|acc            |0.2504|±  |0.0043|
|         |       |none  |     0|acc_norm       |0.2504|±  |0.0043|

Eval of the quantized original checkpoint (no fine-tuning):

|  Tasks  |Version|Filter|n-shot|    Metric     | Value |   |Stderr|
|---------|------:|------|-----:|---------------|------:|---|------|
|wikitext |      2|none  |     0|word_perplexity|12.3473|±  |N/A   |
|         |       |none  |     0|byte_perplexity| 1.6000|±  |N/A   |
|         |       |none  |     0|bits_per_byte  | 0.6781|±  |N/A   |
|hellaswag|      1|none  |     0|acc            | 0.5596|±  |0.0050|
|         |       |none  |     0|acc_norm       | 0.7421|±  |0.0044|
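
For reading the two tables side by side: HellaSwag is a four-way multiple-choice task, so random guessing gives about 0.25 accuracy. The sketch below checks whether a reported accuracy is statistically indistinguishable from chance. The helper name `is_at_chance` and the two-standard-error threshold are illustrative assumptions, not part of torchtune or the eval harness.

```python
def is_at_chance(acc: float, stderr: float, num_choices: int = 4, z: float = 2.0) -> bool:
    """Return True if `acc` lies within `z` standard errors of random guessing.

    Hypothetical helper for illustration; `z = 2.0` is an assumed threshold.
    """
    chance = 1.0 / num_choices
    return abs(acc - chance) <= z * stderr


# Values copied from the hellaswag `acc` rows of the two tables above.
finetuned_at_chance = is_at_chance(acc=0.2504, stderr=0.0043)
original_at_chance = is_at_chance(acc=0.5596, stderr=0.0050)

print(f"fine-tuned at chance: {finetuned_at_chance}")
print(f"original at chance:   {original_at_chance}")
```

Under these assumptions the fine-tuned checkpoint is at chance level while the original is not. Together with the NaN perplexities, that suggests the fine-tuned weights are effectively unusable rather than just slightly degraded.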

Some relevant configs:

dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: allenai/c4
  max_seq_len: 8192
  column: text
  name: en
  split: train
seed: null
shuffle: True

batch_size: 2
epochs: 1
max_steps_per_epoch: 2000

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  foreach: False
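
For scale, the config above bounds how many tokens the run can see. The following is a rough sketch, assuming `batch_size` is per device, `max_steps_per_epoch` counts optimizer steps, and every sample fills the whole `max_seq_len` (real C4 documents are often shorter, so this is an upper bound). The function name and the `grad_accum_steps` parameter are illustrative, not taken from the recipe.

```python
def max_tokens_per_rank(batch_size: int, max_seq_len: int, steps: int,
                        grad_accum_steps: int = 1) -> int:
    """Upper bound on tokens one device processes over the run.

    Hypothetical helper; assumes every sequence is exactly `max_seq_len` long.
    """
    return batch_size * max_seq_len * steps * grad_accum_steps


# batch_size, max_seq_len and max_steps_per_epoch come from the config above;
# gradient accumulation is not shown there, so 1 is assumed.
budget = max_tokens_per_rank(batch_size=2, max_seq_len=8192, steps=2000)
print(f"{budget:,} tokens per rank at most")
```

Multiply by the number of devices to get the global bound. The device count is not stated in this thread.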

gau-nernst commented 2 months ago

I had problems fine-tuning Llama3.1 with torchtune too (i.e. the fine-tuned model performs worse than the original). I think one problem is that the Llama3 recipes in torchtune use the instruct version, which can be difficult to fine-tune.

> The exact same workflow used to work in late June (can't find the exact commit now) but seems to be broken in recent commits (8/13/24).

Do you remember whether you used Llama2 or Llama3? Fine-tuning base Llama2 (the non-instruct version) works fine for me.

andrewor14 commented 2 months ago

> Do you remember whether you used Llama2 or Llama3?

This was Llama3-8B.