**Open** · LuisVasquezBSC opened 1 week ago
Thanks for this detailed report. The easiest fix is probably to remove all columns in `dataset.map`:

```python
dataset.map(..., remove_columns=dataset.column_names)
```

What do you think? Would you like to make a PR to fix this?
### Problem

Some DPO datasets, like Argilla Math, have columns other than `"prompt"`, `"chosen"`, and `"rejected"`, for example `"metadata"`, `"chosen_rating"`, etc. Currently (`trl==0.12.0`), training with KTO (`kto.py`) on this type of dataset crashes before tokenization even starts.

### Minimally reproducible code
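A minimal sketch that reproduces the crash; the toy column `"chosen_rating"` stands in for the extra columns of a real dataset like Argilla Math, and `unpair_preference_dataset` comes from the `trl.data_utils` module linked below:

```python
from datasets import Dataset
from trl.data_utils import unpair_preference_dataset

# Toy paired-preference dataset with one extra column beyond
# "prompt"/"chosen"/"rejected".
dataset = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?"],
        "chosen": ["4"],
        "rejected": ["5"],
        "chosen_rating": [9.0],  # extra column that triggers the crash
    }
)

# Each paired row becomes two unpaired rows, but "chosen_rating" is not
# duplicated to match, so dataset.map fails while writing the result.
unpair_preference_dataset(dataset)
```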
Running this, the error trace ends with a `pyarrow.lib.ArrowInvalid` error complaining that a column has an unexpected length (see the explanation below).
### Possible explanation

Looking at the implementation of `unpair_preference_dataset`, the problem seems to be the application of `_unpair_row` with `dataset.map`:

https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L202-L204
https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L240

Note that `remove_columns=["chosen", "rejected"]` means that only these two columns are dropped from the original dataset.

This is the implementation of `_unpair_row`:

https://github.com/huggingface/trl/blob/623963126be5598bd1eea4ec82b43447fcc11535/trl/data_utils.py#L191-L199
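(For reference, a paraphrase of the linked helper; see the permalink for the exact code at the pinned commit.)

```python
def _unpair_row(examples):
    # Each paired row yields two unpaired rows: the chosen completion with
    # label True and the rejected completion with label False.
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    if "prompt" in examples:
        new_rows["prompt"] = examples["prompt"] + examples["prompt"]
    return new_rows
```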
As we can see, `_unpair_row` outputs a batch with twice as many rows as it received. After applying `dataset.map`, we get a shape mismatch because the undropped columns (those other than `"chosen"` and `"rejected"`) have not been duplicated. This is a known issue with `dataset.map`; more details are in this HF tutorial (Ctrl+F "ArrowInvalid"). The solutions recommended by HF are the two described below.
### Solution 1: Drop all columns irrelevant for KTO

If we only use `unpair_preference_dataset` internally, it is easier to drop all of the old columns when applying `dataset.map`. This can be achieved by giving `unpair_preference_dataset` a new argument, `remove_columns`. I have implemented this and can now run the `KTOTrainer` without problems; a sketch of the change is below.
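This is a sketch of the proposed change, not the exact patch: the `remove_columns` parameter is the new argument suggested in this issue (it does not exist in `trl==0.12.0`), and `_unpair_row` is trl's private helper linked above.

```python
from typing import Optional
from datasets import Dataset
from trl.data_utils import _unpair_row  # private helper linked above

def unpair_preference_dataset(
    dataset: Dataset,
    num_proc: Optional[int] = None,
    remove_columns: Optional[list[str]] = None,
) -> Dataset:
    # Default to the current behavior: drop only the two paired columns.
    if remove_columns is None:
        remove_columns = ["chosen", "rejected"]
    return dataset.map(
        _unpair_row,
        batched=True,
        remove_columns=remove_columns,
        num_proc=num_proc,
    )

# Dropping every original column avoids the shape mismatch:
# unpaired = unpair_preference_dataset(dataset, remove_columns=dataset.column_names)
```

### Solution 2: Process all extra columns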
The HF tutorial above mentions ways to redefine the function passed to `dataset.map` so that all the columns match. For example, see the progressive evolution of the `tokenize_and_split` implementation in their examples. I have not tried this approach, but a sketch of what it could look like for `_unpair_row` is below.
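A minimal sketch of this direction, untested as noted above: have the mapped function itself duplicate every extra column so all output columns end up with the same length (the helper name `_unpair_row_keep_extras` is hypothetical):

```python
def _unpair_row_keep_extras(examples):
    batch_size = len(examples["chosen"])
    new_rows = {
        "completion": examples["chosen"] + examples["rejected"],
        "label": [True] * batch_size + [False] * batch_size,
    }
    # Duplicate every remaining column ("prompt", "metadata",
    # "chosen_rating", ...) so each has 2 * batch_size entries, matching
    # the new row count. Per-pair metadata such as "chosen_rating" is
    # simply copied onto both unpaired rows.
    for column, values in examples.items():
        if column not in ("chosen", "rejected"):
            new_rows[column] = values + values
    return new_rows

# Used with remove_columns=["chosen", "rejected"], every column now has a
# consistent length:
# dataset.map(_unpair_row_keep_extras, batched=True,
#             remove_columns=["chosen", "rejected"])
```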