KTO train_loss = 0.0 #1722

Open yaoxiao1999 opened 4 weeks ago

yaoxiao1999 commented 4 weeks ago


I tried to fine-tune Llama 3 with KTO, but the training loss came out as zero: {'train_runtime': 422.7951, 'train_samples_per_second': 1.017, 'train_steps_per_second': 0.059, 'train_loss': 0.0, 'epoch': 4.65}.

Here's the run summary given by wandb: total_flos 0.0, train/epoch 4.65116, train/global_step 25, train/grad_norm 0.0, train/kl 0.0, train/learning_rate 0.0, train/loss 0.0, train_loss 0.0, train_runtime 422.7951, train_samples_per_second 1.017, train_steps_per_second 0.059

Previously, I ran into the same problem as this thread, and I added loss.requires_grad_(True) before this line in transformers/ as suggested in the thread. I'm not sure if this has caused the problem by any chance.

Here's my code:

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, BitsAndBytesConfig
from trl import KTOConfig, KTOTrainer, create_reference_model
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

dataset = load_dataset("csv", data_files="data.csv", split='train')
dataset = dataset.class_encode_column("label")
split_dataset = dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="label")
ds_train = split_dataset['train']
ds_valid = split_dataset['test']

def train_with_kto():
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
    ref_model = create_reference_model(model)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    training_args = KTOConfig(

    trainer = KTOTrainer(



I also got a few warnings, although I'm not sure if they're related to the issue:

Warning: The default cache directory for DeepSpeed Triton autotune, /users/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.1), only 1.0.0 is known to be compatible

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/users/.conda/envs/kto/lib/python3.9/site-packages/trl/trainer/ UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig we have set it for you, but you should do it yourself in the future.
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Could not estimate the number of tokens of the input, floating-point operations will not be computed

Thank you!

younesbelkada commented 4 weeks ago

cc @kawine 🙏

kawine commented 3 weeks ago

@yaoxiao1999 do you have this issue if you swap KTOTrainer with DPOTrainer?

Also, can you upload a subset of your dataset to HF? I want to reproduce this issue exactly.

kawine commented 3 weeks ago

also cc'ing @kashif since this seems to be a recent recurring issue

yaoxiao1999 commented 3 weeks ago

@kawine Thank you for getting back! Please find a subset of my data at "yaox99/kto_data"

yaoxiao1999 commented 3 weeks ago

@kawine Hi, I swapped KTO with DPO, and the training loss was no longer zero:

{'loss': 0.6931, 'grad_norm': 96.59649658203125, 'learning_rate': 1e-05, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -396.864013671875, 'logps/chosen': -455.818115234375, 'logits/rejected': -3.2099649906158447, 'logits/chosen': -3.25046706199646, 'epoch': 0.36}
{'train_runtime': 73.8844, 'train_samples_per_second': 2.91, 'train_steps_per_second': 0.135, 'train_loss': 0.16300532817840577, 'epoch': 3.64}

kawine commented 3 weeks ago

@yaoxiao1999 yaox99/kto_data on HF seems to be private -- can you make it public?

kawine commented 3 weeks ago

also, can you tell me how you're launching the job? if it's via accelerate, can you paste the accelerate config that you're using?

yaoxiao1999 commented 3 weeks ago

@kawine My apologies! The dataset is now public.

I got the outputs in the original post using just python

If I use accelerate launch, I will get the following errors:

  File "/mnt/parscratch/users/repos/kto/", line 90, in <module>
  File "/mnt/parscratch/users/repos/kto/", line 83, in train_with_kto
  File "/users/.local/lib/python3.9/site-packages/transformers/", line 1876, in train
    return inner_training_loop(
  File "/users/.local/lib/python3.9/site-packages/transformers/", line 2279, in _inner_training_loop
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/", line 157, in step
    self.scaler.step(self.optimizer, closure)
  File "/users/.local/lib/python3.9/site-packages/torch/amp/", line 449, in step
    assert (
AssertionError: No inf checks were recorded for this optimizer.
Traceback (most recent call last):
  File "/users/.conda/envs/kto/bin/accelerate", line 8, in <module>
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/", line 48, in main
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/", line 1088, in launch_command
  File "/users/.conda/envs/kto/lib/python3.9/site-packages/accelerate/commands/", line 733, in multi_gpu_launcher
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/", line 870, in run
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/launcher/", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/users/.local/lib/python3.9/site-packages/torch/distributed/launcher/", line 263, in launch_agent
    raise ChildFailedError(
============================================================ FAILED
Root Cause (first observed failure):
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 44406)
  error_file: <N/A>
  traceback : To enable traceback see:

This is my accelerate config:

- `Accelerate` version: 0.31.0
- Platform: Linux-3.10.0-1160.105.1.el7.x86_64-x86_64-with-glibc2.17
- `accelerate` bash location: /users/.conda/envs/kto/bin/accelerate
- Python version: 3.9.19
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 503.41 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: fp16
    - use_cpu: False
    - debug: True
    - num_processes: 4
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []

Many thanks!

kawine commented 3 weeks ago

@yaoxiao1999 I think you're having the same issue with DPO. Note that although the DPO loss at the first logged step is 0.6931, the rewards are all zero. This means that the policy and the reference model are producing exactly the same probabilities, so log [policy(y|x) / reference(y|x)] = 0. Thus the DPO loss is -log_e sigmoid(reward(y_w) - reward(y_l)) = -log_e sigmoid(0) = -log_e (1/2) = log_e 2 ≈ 0.6931, which matches the logged value.
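The arithmetic above can be checked with a few lines of plain Python. This is only a sketch of the standard DPO loss formula, not TRL's implementation: the function name `dpo_loss` and the value `beta=0.1` are assumptions for illustration, and the log-probabilities are the `logps/chosen` and `logps/rejected` values from the log above, reused for both policy and reference to mimic an identical pair of models.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def dpo_loss(policy_chosen_logp: float, ref_chosen_logp: float,
             policy_rejected_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one pair; beta=0.1 is an illustrative assumption."""
    # Implicit rewards are scaled log-ratios between policy and reference.
    reward_chosen = beta * (policy_chosen_logp - ref_chosen_logp)
    reward_rejected = beta * (policy_rejected_logp - ref_rejected_logp)
    return -math.log(sigmoid(reward_chosen - reward_rejected))


# Policy and reference assign identical log-probs, so both rewards are 0.
loss = dpo_loss(-455.818, -455.818, -396.864, -396.864)
print(round(loss, 4))  # 0.6931
```

Because both rewards are zero whenever the two models agree, the result is log 2 regardless of `beta` or of the actual log-probabilities.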

So the policy model isn't being updated, and this issue is upstream of KTO, something to do with either the Trainer class used by trl or with your installation or package versions. I wasn't able to reproduce the problem on my end, even with your data.
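One possible way to narrow this down is to count how many parameters are actually marked trainable before training starts. If that count is zero, no gradient can reach the weights, which would be consistent with a grad_norm of 0.0. The sketch below is an assumption-laden illustration rather than a confirmed diagnosis: `summarize_trainable` is a hypothetical helper, and `FakeParam` is a stand-in so the snippet runs without PyTorch. With a real model you would pass `model.named_parameters()`, whose items expose the same `requires_grad` attribute and `numel()` method.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter (illustration only)."""
    size: int
    requires_grad: bool

    def numel(self) -> int:
        return self.size


def summarize_trainable(named_params):
    """Return (trainable, total) element counts over (name, param) pairs."""
    trainable = 0
    total = 0
    for _name, param in named_params:
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total


# A fully frozen model: nothing would be updated by the optimizer.
frozen = [
    ("embed.weight", FakeParam(1000, False)),
    ("lm_head.weight", FakeParam(1000, False)),
]
trainable, total = summarize_trainable(frozen)
print(f"{trainable}/{total} trainable")  # 0/2000 trainable
```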

yaoxiao1999 commented 3 weeks ago

@kawine Thank you! Would you mind sharing with me your list of packages so that I can see which versions I should get?

kawine commented 3 weeks ago

@yaoxiao1999 i created a conda env with python 3.12.2 and installed trl directly from the github repo (followed by flashattention2 with pip). list of all packages in my env:

