huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl

KTO train_loss = 0.0 #1722

Open · yaoxiao1999 opened this issue 4 weeks ago

yaoxiao1999 commented 4 weeks ago

Hello,

I tried to fine-tune Llama 3 with KTO, but I got zero training loss: {'train_runtime': 422.7951, 'train_samples_per_second': 1.017, 'train_steps_per_second': 0.059, 'train_loss': 0.0, 'epoch': 4.65}.

Here's the run summary given by wandb: total_flos 0.0, train/epoch 4.65116, train/global_step 25, train/grad_norm 0.0, train/kl 0.0, train/learning_rate 0.0, train/loss 0.0, train_loss 0.0, train_runtime 422.7951, train_samples_per_second 1.017, train_steps_per_second 0.059

Previously, I ran into the same problem as this thread, and I added loss.requires_grad_(True) before this line in transformers/trainer.py as suggested in the thread. I'm not sure whether that change could be causing this problem.
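
To illustrate what that edit does (a standalone toy example, not my actual training code): if the loss has already been detached from the computation graph, forcing requires_grad_(True) makes backward() run without error, but no gradients ever reach the model parameters.

import torch

# Toy example: a "loss" whose graph back to the parameter has been severed.
w = torch.nn.Parameter(torch.tensor(1.0))
loss = (2.0 * w).detach()   # stands in for a loss computed without a graph
loss.requires_grad_(True)   # the workaround from the other thread
loss.backward()             # runs without error...
print(w.grad)               # ...but prints None: grad_norm and updates are zero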

Here's my code:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import KTOConfig, KTOTrainer, create_reference_model
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
set_seed(42)

# Load the KTO data and make a stratified 80/20 train/validation split
dataset = load_dataset("csv", data_files="data.csv", split='train')
dataset = dataset.class_encode_column("label")
split_dataset = dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="label")
ds_train = split_dataset['train']
ds_valid = split_dataset['test']

def train_with_kto():
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
    ref_model = create_reference_model(model)  # frozen copy used as the KTO reference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    training_args = KTOConfig(
        output_dir="output",
        per_device_train_batch_size=2,
        num_train_epochs=5,
        learning_rate=1e-5,
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=8,
        eval_steps=500,
        warmup_ratio=0.1,
        logging_first_step=True,
        max_prompt_length=128,
        max_length=512,
    )

    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=ds_train,
        eval_dataset=ds_valid,
    )

    trainer.train()
    trainer.push_to_hub()

torch.cuda.empty_cache()  # clear any cached allocations before loading the model
train_with_kto()

I also got a few warnings, although I'm not sure if they're related to the issue:

Warning: The default cache directory for DeepSpeed Triton autotune, /users/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.1), only 1.0.0 is known to be compatible

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/users/.conda/envs/kto/lib/python3.9/site-packages/trl/trainer/kto_trainer.py:504: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig we have set it for you, but you should do it yourself in the future.
  warnings.warn(
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Could not estimate the number of tokens of the input, floating-point operations will not be computed

Thank you!

younesbelkada commented 4 weeks ago

cc @kawine 🙏

kawine commented 3 weeks ago

@yaoxiao1999 do you have this issue if you swap KTOTrainer with DPOTrainer?

Also, can you upload a subset of your dataset to HF? I want to reproduce this issue exactly.

kawine commented 3 weeks ago

also cc'ing @kashif since this seems to be a recent recurring issue

yaoxiao1999 commented 3 weeks ago

@kawine Thank you for getting back! Please find a subset of my data at "yaox99/kto_data"

yaoxiao1999 commented 3 weeks ago

@kawine Hi, I swapped KTOTrainer for DPOTrainer, and the training loss was no longer zero:

{'loss': 0.6931, 'grad_norm': 96.59649658203125, 'learning_rate': 1e-05, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -396.864013671875, 'logps/chosen': -455.818115234375, 'logits/rejected': -3.2099649906158447, 'logits/chosen': -3.25046706199646, 'epoch': 0.36}
{'train_runtime': 73.8844, 'train_samples_per_second': 2.91, 'train_steps_per_second': 0.135, 'train_loss': 0.16300532817840577, 'epoch': 3.64}

kawine commented 3 weeks ago

@yaoxiao1999 yaox99/kto_data on HF seems to be private -- can you make it public?

kawine commented 3 weeks ago

also, can you tell me how you're launching the job? if it's via accelerate, can you paste the accelerate config that you're using?

yaoxiao1999 commented 3 weeks ago

@kawine My apologies! The dataset is now public.

I got the outputs in the original post using just python kto.py.

If I use accelerate launch, I get the following errors:

  File "/mnt/parscratch/users/repos/kto/kto.py", line 90, in <module>
    train_with_kto()
  File "/mnt/parscratch/users/repos/kto/kto.py", line 83, in train_with_kto
    trainer.train()
  File "/users/.local/lib/python3.9/site-packages/transformers/trainer.py", line 1876, in train
    return inner_training_loop(
  File "/users/.local/lib/python3.9/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    self.optimizer.step()
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/optimizer.py", line 157, in step
    self.scaler.step(self.optimizer, closure)
  File "/users/.local/lib/python3.9/site-packages/torch/amp/grad_scaler.py", line 449, in step
    assert (
AssertionError: No inf checks were recorded for this optimizer.
Traceback (most recent call last):
  File "/users/.conda/envs/kto/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/launch.py", line 1088, in launch_command
    multi_gpu_launcher(args)
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/launch.py", line 733, in multi_gpu_launcher
    distrib_run.run(args)
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
kto_tune_exp.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 44406)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

This is my accelerate config (the output of accelerate env):

- `Accelerate` version: 0.31.0
- Platform: Linux-3.10.0-1160.105.1.el7.x86_64-x86_64-with-glibc2.17
- `accelerate` bash location: /users/.conda/envs/kto/bin/accelerate
- Python version: 3.9.19
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 503.41 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: fp16
    - use_cpu: False
    - debug: True
    - num_processes: 4
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []

Many thanks!

kawine commented 3 weeks ago

@yaoxiao1999 i think you're having the same issue with DPO. Note that although the DPO loss is 0.6931, the rewards are all zero. This means the policy and the reference model are producing exactly the same probabilities, so log[policy(y|x) / reference(y|x)] = 0. The DPO loss is therefore -log sigmoid(reward(y_w) - reward(y_l)) = -log sigmoid(0) = -log(1/2) = log 2 ≈ 0.6931.
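
A quick numeric check of that identity (a standalone sketch, not taken from the run):

import math

# With identical policy and reference models, both implicit rewards are zero,
# so the DPO loss collapses to -log(sigmoid(0)) = log(2).
loss = -math.log(1.0 / (1.0 + math.exp(-0.0)))
print(loss)  # 0.6931..., matching the logged train loss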

So the policy model isn't being updated, and this issue is upstream of KTO: it has something to do with either the Trainer class used by trl or with your installation or package versions. I wasn't able to reproduce the problem on my end, even with your data. One quick check you can run locally is sketched below.
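
A hypothetical probe (assuming the trainer object from your script): snapshot one weight before training and compare it afterwards; if the tensors are bitwise identical, the optimizer never touched the policy.

import torch

# Hypothetical check, wrapped around trainer.train() from the script above.
name, param = next(trainer.model.named_parameters())
before = param.detach().clone()
trainer.train()
after = dict(trainer.model.named_parameters())[name]
print(name, torch.equal(before.cpu(), after.detach().cpu()))  # True -> never updated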

yaoxiao1999 commented 3 weeks ago

@kawine Thank you! Would you mind sharing with me your list of packages so that I can see which versions I should get?

kawine commented 3 weeks ago

@yaoxiao1999 i created a conda env with python 3.12.2 and installed trl directly from the github repo (followed by flashattention2 with pip). list of all packages in my env:

libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.3                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openssl                   3.0.13               h7f8727e_2  
packaging                 24.1                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
peft                      0.11.1                   pypi_0    pypi
pip                       24.0            py312h06a4308_0  
platformdirs              4.2.2                    pypi_0    pypi
protobuf                  5.27.1                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
pyarrow                   16.1.0                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
python                    3.12.2               h996f2a0_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
regex                     2024.5.15                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rich                      13.7.1                   pypi_0    pypi
safetensors               0.4.3                    pypi_0    pypi
sentry-sdk                2.5.1                    pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                69.5.1          py312h06a4308_0  
shtab                     1.7.1                    pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.1                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.12.1                   pypi_0    pypi
tk                        8.6.14               h39e8969_0  
tokenizers                0.19.1                   pypi_0    pypi
torch                     2.3.1                    pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.41.2                   pypi_0    pypi
trl                       0.9.5.dev0               pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tyro                      0.8.4                    pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.2.1                    pypi_0    pypi
wandb                     0.17.1                   pypi_0    pypi
wheel                     0.43.0          py312h06a4308_0  
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1  
yarl                      1.9.4                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_1