huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

online DPO evaluation #2228

Open woshizouguo opened 1 week ago

woshizouguo commented 1 week ago

System Info

trl=0.11.2

Reproduction

I am using the online DPO example script (examples/scripts/dpo_online.py). When I add --eval_strategy=steps and --eval_steps=5, training fails with the following error:

  File "/mnt/task_runtime/trl/examples/scripts/dpo_online.py", line 120, in <module>
    trainer.train()
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/mnt/task_runtime/trl/trl/trainer/online_dpo_trainer.py", line 555, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 4085, in prediction_step
    outputs = model(**inputs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 186, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 201, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 109, in parallel_apply
    output.reraise()
  File "/opt/miniconda/lib/python3.9/site-packages/torch/_utils.py", line 706, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 84, in _worker
    output = module(*input, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'prompt'
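The last line of the traceback comes from ordinary Python keyword-argument checking. The default evaluation loop calls `model(**inputs)`, so every key in the batch dict is passed to `forward()` as a keyword argument. The sketch below reproduces that failure on its own. It assumes, as the error message suggests, that the eval batch still carries a raw `prompt` field. `ToyModel`, `call_like_prediction_step` and the batch contents are hypothetical stand-ins, not TRL or transformers objects.

```python
from typing import Any, Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a causal LM whose forward() accepts tensors only."""

    def forward(self, input_ids: List[int], attention_mask: Optional[List[int]] = None) -> Dict[str, Any]:
        return {"logits": [float(t) for t in input_ids]}

    def __call__(self, **kwargs: Any) -> Dict[str, Any]:
        # Mirrors how nn.Module.__call__ hands keyword arguments to forward().
        return self.forward(**kwargs)


def call_like_prediction_step(model: ToyModel, inputs: Dict[str, Any]) -> Dict[str, Any]:
    # The default prediction_step ends up doing model(**inputs) with the whole batch.
    return model(**inputs)


# Hypothetical eval batch: tokenized fields plus the raw prompt text.
batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "prompt": ["Hello"]}

try:
    call_like_prediction_step(ToyModel(), batch)
except TypeError as err:
    # Same class of error as in the report: 'prompt' is not a forward() parameter.
    print(err)
```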

Expected behavior

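Evaluation should run every eval_steps steps without errors. Instead, the eval loop does not seem to receive the correct inputs. The batch still contains the raw prompt field, and the model's forward() does not accept it.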
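One generic way to keep non-model fields out of `forward()` is to filter the batch against the model's signature. This is the same idea transformers' `Trainer` relies on when `remove_unused_columns=True` drops dataset columns. The sketch below is a hypothetical helper (`split_model_inputs` and `toy_forward` are illustrative names), not the fix TRL adopted. Online DPO evaluation likely still needs the `prompt` to generate completions, so a dedicated `prediction_step` that uses those extra fields may be the better route.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_model_inputs(
    forward: Callable[..., Any], batch: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a batch into (kwargs that forward accepts, everything else).

    Hypothetical helper: if forward takes **kwargs, everything is passed through.
    """
    params = inspect.signature(forward).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(batch), {}
    # Only parameters that can be passed by name are valid keyword targets.
    keyword_ok = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    accepted = {k: v for k, v in batch.items() if k in keyword_ok}
    extra = {k: v for k, v in batch.items() if k not in keyword_ok}
    return accepted, extra


def toy_forward(input_ids, attention_mask=None):
    # Hypothetical model forward that only understands tensor-like inputs.
    return sum(input_ids)


batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "prompt": ["Hello"]}
model_kwargs, extra = split_model_inputs(toy_forward, batch)
print(sorted(model_kwargs))         # ['attention_mask', 'input_ids']
print(sorted(extra))                # ['prompt']
print(toy_forward(**model_kwargs))  # 6
```

The `extra` dict is kept rather than discarded. A custom evaluation step could then still read the prompts for generation while passing only valid arguments to the model.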

kashif commented 1 week ago

@wenxindongwork I suspect we will need our own prediction_step method, since we use our own data collator instead of the default one. The tests didn't catch this bug because eval_steps in the tests was greater than max_steps, so evaluation never ran.
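A little arithmetic shows the testing gap. With a step-based strategy, evaluation fires at global steps that are multiples of `eval_steps`. If `eval_steps` exceeds `max_steps`, no such step is ever reached. The helper below is a simplified, illustrative model of that rule. It ignores `eval_delay` and any end-of-training evaluation, and the numbers are hypothetical rather than taken from TRL's test suite.

```python
from typing import List


def eval_trigger_steps(max_steps: int, eval_steps: int) -> List[int]:
    """Global steps at which a step-based eval would fire (simplified model)."""
    if eval_steps <= 0:
        raise ValueError("eval_steps must be positive")
    # Global steps run from 1 to max_steps inclusive.
    return [step for step in range(1, max_steps + 1) if step % eval_steps == 0]


print(eval_trigger_steps(max_steps=3, eval_steps=5))   # []
print(eval_trigger_steps(max_steps=12, eval_steps=5))  # [5, 10]
```

A regression test would therefore want eval_steps no larger than max_steps, so that the evaluation path actually runs at least once.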