huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0
10.23k stars 1.3k forks

`examples/scripts/kto.py` does not work #2238

Closed reihig-ut closed 1 month ago

reihig-ut commented 1 month ago

System Info

Information

Tasks

Reproduction

The script is the same as written in examples/scripts/kto.py.

conda create -n run_kto python=3.11
conda activate run_kto
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
python examples/scripts/kto.py \
    --dataset_name trl-lib/kto-mix-14k \
    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
    --per_device_train_batch_size 16 \
    --num_train_epochs 1 \
    --learning_rate 5e-7 \
    --lr_scheduler_type=cosine \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir=kto-aligned-model \
    --warmup_ratio 0.1 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --trust_remote_code

Then it says

Map:   0%|                                                                                                                          | 0/13500 [00:00<?, ? examples/s]
Traceback (most recent call last):
  File "/home/hoge/project/test/trl/examples/scripts/kto.py", line 114, in <module>
    dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/dataset_dict.py", line 866, in map
    {
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/dataset_dict.py", line 867, in <dictcomp>
    k: dataset.map(
       ^^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 560, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3035, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3408, in _map_single
    example = apply_function_on_filtered_inputs(example, i, offset=offset)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3300, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoge/project/test/trl/examples/scripts/kto.py", line 107, in format_dataset
    example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 1812, in apply_chat_template
    isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
               ~~~~~~~~~~~~^^^
IndexError: list index out of range

I think the `format_dataset` function in `examples/scripts/kto.py` does not fully support the default dataset `trl-lib/kto-mix-14k`, which possibly contains examples with `len(example["completion"]) == 1`. In that case `example["completion"][:-1]` is an empty list, and `apply_chat_template` fails with the `IndexError` above.
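Until a fix lands, one workaround is to drop the degenerate rows before the failing `map` call. A minimal sketch (the `has_usable_completion` helper is hypothetical, not part of the script; plain dicts stand in for dataset rows):

```python
def has_usable_completion(example):
    """True when example["completion"][:-1] still contains messages,
    i.e. apply_chat_template would not receive an empty conversation."""
    return len(example.get("completion", [])) > 1

# Plain-dict stand-ins for rows; real rows come from trl-lib/kto-mix-14k.
rows = [
    {"completion": [{"role": "user", "content": "hi"},
                    {"role": "assistant", "content": "hello"}]},
    {"completion": [{"role": "assistant", "content": "lone reply"}]},  # would trigger the IndexError
]

kept = [row for row in rows if has_usable_completion(row)]
```

With a real `datasets.Dataset` this would be something like `dataset = dataset.filter(has_usable_completion, num_proc=training_args.dataset_num_proc)` placed just before the `map` that crashes.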

Expected behavior

The script completes training.

qgallouedec commented 1 month ago

Thanks for pointing this out, #2248 will fix it

reihig-ut commented 1 month ago

Thank you for your PR!

I retried the reproduction steps on the `kto-conv-data-support` branch and got this error:

/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/trl/trainer/kto_trainer.py:479: UserWarning: When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init it will be set to `512` by default, but you should do it yourself in the future.
  warnings.warn(
/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/trl/trainer/kto_trainer.py:489: UserWarning: When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init it will be set to `128` by default, but you should do it yourself in the future.
  warnings.warn(
/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/trl/trainer/kto_trainer.py:519: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig we have set it for you, but you should do it yourself in the future.
  warnings.warn(
Traceback (most recent call last):
  File "/home/hoge/project/test/trl/examples/scripts/kto.py", line 97, in <module>
    trainer = KTOTrainer(
              ^^^^^^^^^^^
  File "/home/hoge/miniconda3/envs/run_kto/lib/python3.11/site-packages/trl/trainer/kto_trainer.py", line 721, in __init__
    super().__init__(
TypeError: Trainer.__init__() got an unexpected keyword argument 'processing_class'

benchay1999 commented 1 month ago

Changing `processing_class` to `tokenizer` worked for me.
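The mismatch comes from the `Trainer` keyword changing across transformers versions: older releases accept `tokenizer`, newer ones `processing_class`. If you need to support both, a hedged compat sketch (the `tokenizer_kwarg` helper is made up for illustration) can inspect the signature:

```python
import inspect

def tokenizer_kwarg(trainer_cls, tokenizer):
    """Return the tokenizer under whichever keyword this Trainer accepts."""
    params = inspect.signature(trainer_cls.__init__).parameters
    key = "processing_class" if "processing_class" in params else "tokenizer"
    return {key: tokenizer}

# Stand-ins for the old and new Trainer signatures.
class OldTrainer:
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer

class NewTrainer:
    def __init__(self, processing_class=None):
        self.processing_class = processing_class
```

You would then call, e.g., `KTOTrainer(model, args, **tokenizer_kwarg(KTOTrainer, tokenizer), ...)`; pinning a matching transformers version is the simpler fix.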

kashif commented 1 month ago

This should be fixed now in `main` with the latest transformers release.

chenyang399 commented 3 weeks ago

How much memory does the KTO script need? Does it require more than 24 GB of GPU memory? I tried it on a 4090 with 24 GB and it failed.
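For a rough sense of why 24 GB is tight: full-parameter KTO keeps both a policy and a frozen reference copy of the 1.8B-parameter model, plus gradients and Adam optimizer states. A back-of-envelope sketch (assumptions: bf16 weights and gradients, two fp32 Adam moments, activations and CUDA overhead not counted):

```python
# Back-of-envelope GPU memory for full-parameter KTO on a 1.8B model.
params = 1.8e9
GB = 1e9

policy_weights = 2 * params / GB      # bf16 policy model, 2 bytes/param
reference_weights = 2 * params / GB   # frozen reference model, also bf16
gradients = 2 * params / GB           # bf16 gradients for the policy
adam_states = 8 * params / GB         # two fp32 moments, 4 bytes/param each

total = policy_weights + reference_weights + gradients + adam_states
print(f"~{total:.1f} GB before activations")
```

Even before activations for a batch of 16 sequences, this already exceeds 24 GB, so a smaller `per_device_train_batch_size` with gradient accumulation, gradient checkpointing, or LoRA-style fine-tuning are the usual ways to fit it on a 4090.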