huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

PPOTrainer breaks when gradient_accumulation_steps > 1 #648

Closed (kushalarora closed this issue 1 year ago)

kushalarora commented 1 year ago

PPOTrainer throws the following error when the --gradient_accumulation_steps argument is set to 2 or more.

$ python trl/examples/scripts/sentiment_tuning.py --gradient_accumulation_steps 2
[2023-08-15 20:35:29,345] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Using /data/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /data/.cache/torch_extensions/py38_cu117/cuda_kernel/build.ninja...
Building extension module cuda_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cuda_kernel...
0it [00:00, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
/home/ubuntu/envs/mpror/lib/python3.8/site-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.
  warnings.warn(
0it [00:14, ?it/s]
Traceback (most recent call last):
  File "trl/examples/scripts/sentiment_tuning.py", line 203, in <module>
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/wdir/cold_reinforce/trl/trl/trainer/ppo_trainer.py", line 739, in step
    logprobs, logits, vpreds, _ = self.batched_forward_pass(
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/wdir/cold_reinforce/trl/trl/trainer/ppo_trainer.py", line 967, in batched_forward_pass
    torch.cat(all_logprobs),
RuntimeError: torch.cat(): expected a non-empty list of Tensors

TRL version: 0.5.1.dev0

$ pip show trl
Name: trl
Version: 0.5.1.dev0
Summary: A Pytorch implementation of Proximal Policy Optimization for transfomer language models.
Home-page: https://github.com/huggingface/trl
Author: Leandro von Werra
Author-email: leandro.vonwerra@gmail.com
License: Apache 2.0
Location: /home/ubuntu/wdir/cold_reinforce/trl
Requires: accelerate, datasets, numpy, torch, transformers
Required-by:

Related issue here: https://github.com/huggingface/trl/issues/614

kushalarora commented 1 year ago

On further debugging, it seems this happens when self.config.backward_batch_size > batch_size.

This can happen when mini_batch_size * gradient_accumulation_steps > batch_size, because backward_batch_size is computed as mini_batch_size * gradient_accumulation_steps.

In that case, mini_batch_start in the for loop of the snippet below can reach or exceed batch_size. Slicing backward_batch_inds past its end yields an empty mini_batch_inds, so mini_batch_dict is built with empty entries; for example, mini_batch_dict['queries'] is empty. batched_forward_pass then has no inputs to iterate over and all_logprobs stays empty, which is consistent with the torch.cat() error in the traceback above.

    for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
        mini_batch_end = mini_batch_start + self.config.mini_batch_size
        mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
        mini_batch_dict = {
            "logprobs": batch_dict["logprobs"][mini_batch_inds],
            "values": batch_dict["values"][mini_batch_inds],
            "masks": batch_dict["masks"][mini_batch_inds],
            # ... (remaining entries, such as "queries", omitted in this excerpt)
        }

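
For illustration, here is a minimal pure-Python re-creation of the index slicing described above. It is a sketch under stated assumptions, not TRL code: it assumes backward_batch_size = mini_batch_size * gradient_accumulation_steps, and that backward_batch_inds comes from an outer loop that walks an index list of length batch_size in chunks of backward_batch_size (that outer loop is not shown in the quoted snippet). The helper name mini_batch_lengths is hypothetical.

```python
def mini_batch_lengths(batch_size, mini_batch_size, gradient_accumulation_steps):
    """Return the length of every mini-batch index slice produced by a
    hypothetical re-creation of PPOTrainer's nested batching loops."""
    # Assumption: backward_batch_size is the product of these two settings.
    backward_batch_size = mini_batch_size * gradient_accumulation_steps
    # Assumption: the outer loop walks an (here unshuffled) index list of
    # length batch_size in chunks of backward_batch_size.
    b_inds = list(range(batch_size))
    lengths = []
    for backward_batch_start in range(0, batch_size, backward_batch_size):
        backward_batch_end = backward_batch_start + backward_batch_size
        # Python slicing clamps to the list end, so this slice can be
        # shorter than backward_batch_size.
        backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
        # Inner loop as in the quoted snippet: it always runs up to
        # backward_batch_size, regardless of how many indices are left.
        for mini_batch_start in range(0, backward_batch_size, mini_batch_size):
            mini_batch_end = mini_batch_start + mini_batch_size
            mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
            lengths.append(len(mini_batch_inds))
    return lengths


print(mini_batch_lengths(4, 2, 1))  # [2, 2]  (backward_batch_size = 2 <= 4)
print(mini_batch_lengths(4, 4, 2))  # [4, 0]  (backward_batch_size = 8 > 4)
```

Each 0 in the output corresponds to an empty mini_batch_inds, i.e. a mini-batch for which mini_batch_dict would be built with empty entries.
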
kushalarora commented 1 year ago

@vwxyzjn It seems this was introduced by the refactoring done in #546.

vwxyzjn commented 1 year ago

Ah, thanks for the catch. In this case, we should add a check ensuring self.config.backward_batch_size <= batch_size. I'm going to prepare a PR for this tomorrow.
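
A minimal sketch of such an up-front check, written as a standalone function that rejects configurations where backward_batch_size exceeds batch_size. The name check_batch_sizes is hypothetical and this is not TRL's actual API; it assumes backward_batch_size = mini_batch_size * gradient_accumulation_steps. It also goes one step further by requiring batch_size to be a multiple of backward_batch_size: assuming the outer loop steps through batch_size in chunks of backward_batch_size, the inner loop in the snippet quoted earlier would also produce an empty mini-batch for, e.g., batch_size=6, mini_batch_size=2, gradient_accumulation_steps=2 (the last chunk holds only 2 indices, but the inner loop still slices up to 4).

```python
def check_batch_sizes(batch_size, mini_batch_size, gradient_accumulation_steps):
    """Hypothetical up-front validation (not TRL's API). Returns the derived
    backward_batch_size if the configuration is safe, else raises ValueError."""
    # Assumption: backward_batch_size = mini_batch_size * gradient_accumulation_steps.
    backward_batch_size = mini_batch_size * gradient_accumulation_steps
    # The condition discussed in this thread: a backward batch larger than
    # the whole batch leaves trailing mini-batches empty.
    if backward_batch_size > batch_size:
        raise ValueError(
            f"mini_batch_size * gradient_accumulation_steps ({backward_batch_size}) "
            f"must not exceed batch_size ({batch_size})"
        )
    # Stricter extra check: a partial last backward batch can also leave a
    # mini-batch empty, so require an exact multiple.
    if batch_size % backward_batch_size != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be a multiple of "
            f"mini_batch_size * gradient_accumulation_steps ({backward_batch_size})"
        )
    return backward_batch_size


print(check_batch_sizes(8, 2, 2))  # 4
for args in [(4, 4, 2), (6, 2, 2)]:
    try:
        check_batch_sizes(*args)
    except ValueError as err:
        print(f"{args}: {err}")
```

Failing fast at configuration time turns the opaque torch.cat() error into a message that names the offending settings.
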