huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

KTO: `unpair_preference_dataset` does not work for datasets with additional columns #2351

Open LuisVasquezBSC opened 1 week ago

LuisVasquezBSC commented 1 week ago

Problem

Some DPO datasets, like Argilla Math, have columns other than "prompt", "chosen", and "rejected", such as "metadata" and "chosen_rating".

Currently (trl==0.12.0), training with KTO (kto.py) on this type of dataset crashes before tokenization starts.

Minimally reproducible code

from datasets import load_dataset
from trl.data_utils import unpair_preference_dataset

dataset = load_dataset("RLHFlow/Argilla-Math-DPO-standard") 
unpair_preference_dataset(dataset)

The error trace ends with:

pyarrow.lib.ArrowInvalid: Column 3 named completion expected length 1000 but got length 2000

Possible explanation

Looking at the implementation of unpair_preference_dataset, the problem seems to be the application of _unpair_row with dataset.map.

https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L202-L204
https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L240

Note that remove_columns=["chosen", "rejected"] means these two columns are dropped from the input before the mapped output is merged back in.

This is the implementation of _unpair_row: https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L191-L199

As we can see, _unpair_row outputs batches whose columns are twice the length of the input batch. After applying dataset.map, we get a shape mismatch because the undropped columns (those other than "chosen" and "rejected") have not been duplicated, so they keep their original length.
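The mismatch can be reproduced with plain Python dicts, without the `datasets` library. A minimal sketch: the `_unpair_row` body below is paraphrased from trl/data_utils.py, and the extra column name "chosen_rating" is only illustrative.

```python
def _unpair_row(examples):
    # Each paired row yields two unpaired rows: (chosen, True) and (rejected, False).
    batch_size = len(examples["chosen"])
    return {
        "prompt": examples["prompt"] * 2,
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }

batch = {
    "prompt": ["1+1=", "2+2="],
    "chosen": ["2", "4"],
    "rejected": ["3", "5"],
    "chosen_rating": [5, 4],  # extra column, as in Argilla Math
}

out = _unpair_row(batch)
# dataset.map drops only "chosen"/"rejected" and merges the remaining input
# columns with `out`: "chosen_rating" still has 2 entries while "completion"
# has 4, which is exactly the length mismatch Arrow complains about.
print(len(out["completion"]), len(batch["chosen_rating"]))
```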

This is a known pitfall of dataset.map when a batched function changes the number of rows. More details in this HF tutorial (Ctrl+F "ArrowInvalid").

The recommended solutions from HF are to:

  1. drop all columns not necessary for your script, or
  2. preprocess the extra columns so that the shapes match

Solution 1: Drop all columns irrelevant for KTO

If we only use unpair_preference_dataset internally, it is easier to drop all the old columns when applying dataset.map. This can be achieved by adding a new argument, remove_columns, to unpair_preference_dataset:

def maybe_unpair_preference_dataset(
    dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
) -> DatasetType:
    if isinstance(dataset, DatasetDict):
        column_names = dataset[list(dataset.keys())[0]].column_names
    else:
        column_names = dataset.column_names
    if "chosen" in column_names and "rejected" in column_names:
-        return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc)
+        return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc, remove_columns=column_names)
    else:
        return dataset

def unpair_preference_dataset(
-    dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
+    dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None, remove_columns: Optional[list[str]] = None
) -> DatasetType:
+    """Docstring extended to document the new `remove_columns` argument."""
-    return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc, desc=desc)
+    return dataset.map(_unpair_row, batched=True, remove_columns=remove_columns, num_proc=num_proc, desc=desc)

I have implemented this and can now run the KTOTrainer without problems.

Solution 2: Process all extra columns

The HF tutorial above mentions ways to redefine the function passed to dataset.map so that all output columns have matching lengths. For example, see the progressive evolution of the `tokenize_and_split` implementation in their examples. I have not tried this approach.
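In the same spirit, here is a hedged sketch of what solution 2 could look like for KTO: the mapped function duplicates every extra column itself, so all output columns end up with the doubled length. The function name `_unpair_row_keep_extras` and the column names are illustrative, not part of trl; the sketch again uses plain dicts instead of a real `datasets.Dataset`.

```python
def _unpair_row_keep_extras(examples):
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    # Duplicate every column we are not consuming, so all lengths stay consistent.
    for name, values in examples.items():
        if name not in ("chosen", "rejected"):
            new_rows[name] = values * 2
    return new_rows

batch = {
    "prompt": ["1+1="],
    "chosen": ["2"],
    "rejected": ["3"],
    "chosen_rating": [5],
}
out = _unpair_row_keep_extras(batch)
# Every output column now has length 2, so a call like
# dataset.map(_unpair_row_keep_extras, batched=True,
#             remove_columns=["chosen", "rejected"])
# should no longer hit ArrowInvalid.
```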

qgallouedec commented 2 days ago

Thanks for this detailed report. The easiest fix is probably to remove all columns in dataset.map:

dataset.map(..., remove_columns=dataset.column_names)

What do you think? Would you like to make a PR to fix this?