CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

PPO training fails with NCCL timeout when running on larger models #373

Closed: agave233 closed this issue 1 year ago

agave233 commented 1 year ago

Hello,

I have successfully run the summarize_rlhf example with small SFT and reward models (bloom1b). However, when I try a larger model (7B), a timeout error is raised. This looks similar to the problem described in #319, but I have not been able to find a solution there.

My environment:
- trlx: 0.5.0
- accelerate: 0.17.1
- torch: 1.13.1

The error is as follows:

[rollout 15 / 16]:  94%|█████████▍| 15/16 [00:25<00:01,  1.81s/it]
[rollout 15 / 16]:  94%|█████████▍| 15/16 [00:27<00:01,  1.81s/it]
[E ProcessGroupNCCL.cpp:737] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2059, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1800283 milliseconds before timing out.
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /root/paddlejob/workspace/env_run/trlx_ppo_bloom_new.py:197 in <module>      │
│                                                                              │
│   194 │   │   post_summary_dict[val_prompts[i]] = val_summaries[i]           │
│   195 │                                                                      │
│   196 │   reward_fn = create_reward_fn()                                     │
│ ❱ 197 │   trainer = trlx.train(                                              │
│   198 │   │   reward_fn=reward_fn,                                           │
│   199 │   │   prompts=train_prompts,                                         │
│   200 │   │   eval_prompts=val_prompts[0:1000],  # sampling 1000 validation  │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trlx.py:123 in train       │
│                                                                              │
│   120 │   eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts,  │
│   121 │   trainer.add_eval_pipeline(eval_pipeline)                           │
│   122 │                                                                      │
│ ❱ 123 │   trainer.learn()                                                    │
│   124 │   return trainer                                                     │
│   125                                                                        │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_base_tr │
│ ainer.py:478 in learn                                                        │
│                                                                              │
│   475 │   │   │   │   │   # multiple gradient updates on the same batch of d │
│   476 │   │   │   │   │   # https://arxiv.org/pdf/1707.06347.pdf             │
│   477 │   │   │   │   │   forward_time = time()                              │
│ ❱ 478 │   │   │   │   │   loss, stats = self.loss(batch)                     │
│   479 │   │   │   │   │   forward_time = time() - forward_time               │
│   480 │   │   │   │   │   backward_time = time()                             │
│   481 │   │   │   │   │   self.accelerator.backward(loss)                    │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_ppo_tra │
│ iner.py:151 in loss                                                          │
│                                                                              │
│   148 │   │   old_rewards = batch.rewards.to(self.accelerator.device)        │
│   149 │   │   response_length = old_rewards.shape[1]                         │
│   150 │   │                                                                  │
│ ❱ 151 │   │   advantages, returns = self.config.method.get_advantages_and_re │
│   152 │   │                                                                  │
│   153 │   │   if self.config.model.model_arch_type == "seq2seq":             │
│   154 │   │   │   input_ids = query_tensors                                  │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/models/modeling_ppo.py:169 │
│ in get_advantages_and_returns                                                │
│                                                                              │
│    166 │   │   advantages = torch.stack(advantages_reversed[::-1], dim=1)    │
│    167 │   │   returns = advantages + values                                 │
│    168 │   │   if use_whitening:                                             │
│ ❱  169 │   │   │   advantages = whiten(advantages)                           │
│    170 │   │   return advantages.detach(), returns                           │
│    171 │                                                                     │
│    172 │   def loss(                                                         │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/utils/modeling.py:206 in   │
│ whiten                                                                       │
│                                                                              │
│   203 def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> tor │
│   204 │   """Whitens values"""                                               │
│   205 │   if distributed and dist.is_initialized():                          │
│ ❱ 206 │   │   mean, var, _ = get_global_statistics(xs)                       │
│   207 │   else:                                                              │
│   208 │   │   var, mean = torch.var_mean(xs)                                 │
│   209                                                                        │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/utils/modeling.py:193 in   │
│ get_global_statistics                                                        │
│                                                                              │
│   190 │   Computes element-wise mean and variance of the tensor across proce │
│   191 │   """                                                                │
│   192 │   sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.dev │
│ ❱ 193 │   dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)                  │
│   194 │   global_sum, count = sum_and_count                                  │
│   195 │   global_mean = global_sum / count                                   │
│   196                                                                        │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/torch/distributed/distribut │
│ ed_c10d.py:1320 in all_reduce                                                │
│                                                                              │
│   1317 │   opts.reduceOp = op                                                │
│   1318 │   if group is None:                                                 │
│   1319 │   │   default_pg = _get_default_group()                             │
│ ❱ 1320 │   │   work = default_pg.allreduce([tensor], opts)                   │
│   1321 │   else:                                                             │
│   1322 │   │   work = group.allreduce([tensor], opts)                        │
│   1323                                                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: NCCL communicator was aborted on rank 2.  Original reason for 
failure was: [Rank 2] Watchdog caught collective operation timeout: 
WorkNCCL(SeqNum=2059, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1800283 
milliseconds before timing out.
[E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:737] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2059, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1801398 milliseconds before timing out.
Traceback (most recent call last):
  File "trlx_ppo_bloom_new.py", line 197, in <module>
    trainer = trlx.train(
  File "/root/paddlejob/workspace/env_run/tools/trlx/trlx/trlx.py", line 123, in train
    trainer.learn()
  File "/root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_base_trainer.py", line 545, in learn
    self.post_epoch_callback()
  File "/root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 232, in post_epoch_callback
    self.make_experience(self.config.method.num_rollouts, self.iter_count)
  File "/root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 302, in make_experience
    padded_prompts = self.accelerator.pad_across_processes(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/accelerator.py", line 1885, in pad_across_processes
    return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 413, in pad_across_processes
    return recursively_apply(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 101, in recursively_apply
    return func(data, *args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 394, in _pad_across_processes
    sizes = gather(size).cpu()
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 228, in gather
    return _gpu_gather(tensor)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 208, in _gpu_gather
    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 101, in recursively_apply
    return func(data, *args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations.py", line 205, in _gpu_gather_one
    torch.distributed.all_gather(output_tensors, tensor)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2068, in all_gather
    work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: NCCL communicator was aborted on rank 0.  Original reason for failure was: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2059, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1801398 milliseconds before timing out.
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /root/paddlejob/workspace/env_run/trlx_ppo_bloom_new.py:197 in <module>      │
│                                                                              │
│   194 │   │   post_summary_dict[val_prompts[i]] = val_summaries[i]           │
│   195 │                                                                      │
│   196 │   reward_fn = create_reward_fn()                                     │
│ ❱ 197 │   trainer = trlx.train(                                              │
│   198 │   │   reward_fn=reward_fn,                                           │
│   199 │   │   prompts=train_prompts,                                         │
│   200 │   │   eval_prompts=val_prompts[0:1000],  # sampling 1000 validation  │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trlx.py:123 in train       │
│                                                                              │
│   120 │   eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts,  │
│   121 │   trainer.add_eval_pipeline(eval_pipeline)                           │
│   122 │                                                                      │
│ ❱ 123 │   trainer.learn()                                                    │
│   124 │   return trainer                                                     │
│   125                                                                        │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_base_tr │
│ ainer.py:545 in learn                                                        │
│                                                                              │
│   542 │   │   │   │                                                          │
│   543 │   │   │   │   self.post_backward_callback()                          │
│   544 │   │   │                                                              │
│ ❱ 545 │   │   │   self.post_epoch_callback()                                 │
│   546 │   │   tbar.close()                                                   │
│   547 │                                                                      │
│   548 │   @abstractmethod                                                    │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_ppo_tra │
│ iner.py:232 in post_epoch_callback                                           │
│                                                                              │
│   229 │   │   │   self.store.export_history(location=self.rollout_logging_di │
│   230 │   │   self.store.clear_history()                                     │
│   231 │   │   # Collect more rollouts for training                           │
│ ❱ 232 │   │   self.make_experience(self.config.method.num_rollouts, self.ite │
│   233 │                                                                      │
│   234 │   def post_backward_callback(self):                                  │
│   235 │   │   self.kl_ctl.update(self.mean_kl.item(), n_steps=self.config.tr │
│                                                                              │
│ /root/paddlejob/workspace/env_run/tools/trlx/trlx/trainer/accelerate_ppo_tra │
│ iner.py:302 in make_experience                                               │
│                                                                              │
│   299 │   │   │   padded_samples = self.accelerator.pad_across_processes(    │
│   300 │   │   │   │   samples, dim=1, pad_index=self.tokenizer.eos_token_id, │
│   301 │   │   │   )                                                          │
│ ❱ 302 │   │   │   padded_prompts = self.accelerator.pad_across_processes(    │
│   303 │   │   │   │   prompt_tensors, dim=1, pad_index=self.tokenizer.eos_to │
│   304 │   │   │   )                                                          │
│   305 │   │   │   gathered_samples = self.accelerator.gather(padded_samples) │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/accelerator.py:1 │
│ 885 in pad_across_processes                                                  │
│                                                                              │
│   1882 │   │   torch.Size([2])                                               │
│   1883 │   │   ```                                                           │
│   1884 │   │   """                                                           │
│ ❱ 1885 │   │   return pad_across_processes(tensor, dim=dim, pad_index=pad_in │
│   1886 │                                                                     │
│   1887 │   def unwrap_model(self, model, keep_fp32_wrapper: bool = True):    │
│   1888 │   │   """                                                           │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:413 in pad_across_processes                                              │
│                                                                              │
│   410 │   │   new_tensor[indices] = tensor                                   │
│   411 │   │   return new_tensor                                              │
│   412 │                                                                      │
│ ❱ 413 │   return recursively_apply(                                          │
│   414 │   │   _pad_across_processes, tensor, error_on_other_type=True, dim=d │
│   415 │   )                                                                  │
│   416                                                                        │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:101 in recursively_apply                                                 │
│                                                                              │
│    98 │   │   │   }                                                          │
│    99 │   │   )                                                              │
│   100 │   elif test_type(data):                                              │
│ ❱ 101 │   │   return func(data, *args, **kwargs)                             │
│   102 │   elif error_on_other_type:                                          │
│   103 │   │   raise TypeError(                                               │
│   104 │   │   │   f"Can't apply {func.__name__} on object of type {type(data │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:394 in _pad_across_processes                                             │
│                                                                              │
│   391 │   │                                                                  │
│   392 │   │   # Gather all sizes                                             │
│   393 │   │   size = torch.tensor(tensor.shape, device=tensor.device)[None]  │
│ ❱ 394 │   │   sizes = gather(size).cpu()                                     │
│   395 │   │   # Then pad to the maximum size                                 │
│   396 │   │   max_size = max(s[dim] for s in sizes)                          │
│   397 │   │   if max_size == tensor.shape[dim]:                              │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:228 in gather                                                            │
│                                                                              │
│   225 │   if PartialState().distributed_type == DistributedType.TPU:         │
│   226 │   │   return _tpu_gather(tensor, name="accelerate.utils.gather")     │
│   227 │   elif PartialState().distributed_type in CUDA_DISTRIBUTED_TYPES:    │
│ ❱ 228 │   │   return _gpu_gather(tensor)                                     │
│   229 │   elif PartialState().distributed_type == DistributedType.MULTI_CPU: │
│   230 │   │   return _cpu_gather(tensor)                                     │
│   231 │   else:                                                              │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:208 in _gpu_gather                                                       │
│                                                                              │
│   205 │   │   torch.distributed.all_gather(output_tensors, tensor)           │
│   206 │   │   return torch.cat(output_tensors, dim=0)                        │
│   207 │                                                                      │
│ ❱ 208 │   return recursively_apply(_gpu_gather_one, tensor, error_on_other_t │
│   209                                                                        │
│   210                                                                        │
│   211 _cpu_gather = _gpu_gather                                              │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:101 in recursively_apply                                                 │
│                                                                              │
│    98 │   │   │   }                                                          │
│    99 │   │   )                                                              │
│   100 │   elif test_type(data):                                              │
│ ❱ 101 │   │   return func(data, *args, **kwargs)                             │
│   102 │   elif error_on_other_type:                                          │
│   103 │   │   raise TypeError(                                               │
│   104 │   │   │   f"Can't apply {func.__name__} on object of type {type(data │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/accelerate/utils/operations │
│ .py:205 in _gpu_gather_one                                                   │
│                                                                              │
│   202 │   │   if tensor.ndim == 0:                                           │
│   203 │   │   │   tensor = tensor.clone()[None]                              │
│   204 │   │   output_tensors = [tensor.clone() for _ in range(torch.distribu │
│ ❱ 205 │   │   torch.distributed.all_gather(output_tensors, tensor)           │
│   206 │   │   return torch.cat(output_tensors, dim=0)                        │
│   207 │                                                                      │
│   208 │   return recursively_apply(_gpu_gather_one, tensor, error_on_other_t │
│                                                                              │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/torch/distributed/distribut │
│ ed_c10d.py:2068 in all_gather                                                │
│                                                                              │
│   2065 │                                                                     │
│   2066 │   if group is None:                                                 │
│   2067 │   │   default_pg = _get_default_group()                             │
│ ❱ 2068 │   │   work = default_pg.allgather([tensor_list], [tensor])          │
│   2069 │   else:                                                             │
│   2070 │   │   work = group.allgather([tensor_list], [tensor])               │
│   2071                                                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: NCCL communicator was aborted on rank 0.  Original reason for 
failure was: [Rank 0] Watchdog caught collective operation timeout: 
WorkNCCL(SeqNum=2059, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1801398 
milliseconds before timing out.
wandb: Waiting for W&B process to finish... (failed 1).
[E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.

maxreciprocate commented 1 year ago

Hi, I think I know the solution to this problem; it's only tangentially related to #319. I suspect the cause is that one rank finishes make_experience ahead of the others because empty responses are filtered out.
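The tracebacks above fit this explanation: at the same SeqNum=2059, rank 2 is blocked in an ALLREDUCE inside whiten while rank 0 is blocked in an ALLGATHER inside make_experience, so the two ranks are no longer issuing the same sequence of collectives. The sketch below is a simplified, hypothetical model of that failure mode, not trlx's actual make_experience code: it assumes each rollout step ends in one collective and that each sample is dropped with some probability (standing in for empty-response filtering), then counts how many steps each simulated rank needs. The function name and parameters are illustrative only.

```python
import random


def rollout_steps(num_rollouts: int, batch_size: int, empty_rate: float, seed: int) -> int:
    """Hypothetical model of one rank's rollout loop (not trlx's actual code).

    Each step generates `batch_size` samples and is assumed to end with a
    collective call (e.g. a gather). Each sample is dropped with probability
    `empty_rate`, standing in for filtering of empty responses. Returns how many
    steps, and therefore collective calls, this rank makes before it has
    collected `num_rollouts` samples.
    """
    if not 0.0 <= empty_rate < 1.0:
        raise ValueError("empty_rate must be in [0, 1)")
    rng = random.Random(seed)
    collected, steps = 0, 0
    while collected < num_rollouts:
        steps += 1
        # Keep a sample unless it is "filtered out" as an empty response.
        collected += sum(1 for _ in range(batch_size) if rng.random() >= empty_rate)
    return steps


if __name__ == "__main__":
    seeds = [0, 1, 2]  # one RNG seed per simulated rank
    for rate in (0.0, 0.6):
        counts = [rollout_steps(128, 8, rate, s) for s in seeds]
        print(f"empty_rate={rate}: steps per rank = {counts}")
```

With `empty_rate=0.0` every simulated rank takes exactly 16 steps (128 / 8). With a nonzero rate the counts generally differ, and a rank that leaves the loop early moves on to its next collective while its peers are still waiting in one it will never join.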

agave233 commented 1 year ago

Thanks! That sounds reasonable. So is this a bug, or do I need to change something in my PPO training setup to solve it?

maxreciprocate commented 1 year ago

@agave233 It is a bug, but it only occurs under fairly stochastic conditions: it will not be triggered unless the model collapses to producing empty outputs. To remedy that, you could reduce the learning rate or, if possible, increase the batch size.
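To check whether a run is actually drifting toward the empty-output collapse described above before tuning the learning rate or batch size, one option is to track the share of empty generations in each rollout batch. The helper below is hypothetical (it is not part of trlx) and assumes "empty" means a decoded response that is blank after stripping whitespace; trlx's own filter may define emptiness differently, for example at the token level.

```python
def empty_response_rate(responses: list[str]) -> float:
    """Hypothetical helper: fraction of decoded responses that are empty or whitespace-only."""
    if not responses:
        return 0.0
    empty = sum(1 for text in responses if not text.strip())
    return empty / len(responses)


if __name__ == "__main__":
    batch = ["A short summary.", "", "   ", "Another summary."]
    print(f"empty rate: {empty_response_rate(batch):.2f}")
```

A value that climbs steadily over training suggests the policy is heading toward empty outputs, which is the condition under which this bug is triggered.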

agave233 commented 1 year ago

Thanks for the suggestion. I tried it, but it did not help.

I'm a bit curious why one process can finish this step before the others. Is there no inter-process synchronization mechanism in the process of making experiences?

maxreciprocate commented 1 year ago

Hi @agave233, could you post the script you've used and the git commit, so I can reproduce this particular bug? I'm closing in on a fix for it.

> Is there no inter-process synchronization mechanism in the process of making experiences?

There was no need for it before, apart from corner cases like yours, apparently.
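One way such a synchronization could work, sketched here as an assumption and not as trlx's actual fix, is to make the loop's stopping decision global: every rank keeps generating until *all* ranks have enough samples. In real distributed code the shared condition would come from a collective such as `torch.distributed.all_reduce(count, op=ReduceOp.MIN)`; the sketch below simulates all ranks in one process and uses Python's `min()` in its place. The function name and parameters are hypothetical.

```python
import random


def lockstep_rollout_steps(num_rollouts: int, batch_size: int, empty_rate: float,
                           seeds: list[int]) -> tuple[int, list[int]]:
    """Hypothetical synchronized rollout loop, simulated in a single process.

    Each entry in `seeds` is one simulated rank. The loop condition uses the
    minimum sample count over all ranks, standing in for an all-reduce with a
    MIN op, so every rank makes the same decision and runs the same number of
    steps. Returns the shared step count and each rank's final sample count.
    """
    if not 0.0 <= empty_rate < 1.0:
        raise ValueError("empty_rate must be in [0, 1)")
    rngs = [random.Random(s) for s in seeds]
    collected = [0] * len(seeds)
    steps = 0
    while min(collected) < num_rollouts:  # identical decision on every rank
        steps += 1
        for rank, rng in enumerate(rngs):
            # Each rank still filters its own empty responses independently.
            collected[rank] += sum(1 for _ in range(batch_size) if rng.random() >= empty_rate)
    return steps, collected


if __name__ == "__main__":
    steps, collected = lockstep_rollout_steps(128, 8, 0.6, [0, 1, 2])
    print(f"shared steps: {steps}, samples per rank: {collected}")
```

Ranks that reach `num_rollouts` early keep generating, so each would then truncate its buffer back to `num_rollouts`. The cost is some wasted generation; the benefit is that every rank issues the same sequence of collectives.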

agave233 commented 1 year ago

The timeout problem is resolved with the latest code. Thanks 👍

javirandor commented 1 year ago

I am facing the same problem, but in my case the model does not even start training: it seems to time out in an all-reduce operation. I am trying to train the 1B model with --num_processes 3, using the latest code. Any idea what could be going wrong?

Trace below

[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1802544 milliseconds before timing out.
privsec0:2247378:2247601 [0] NCCL INFO comm 0x4e9f0990 rank 2 nranks 3 cudaDev 2 busId 41000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
privsec0:2247377:2247604 [0] NCCL INFO comm 0x4f31ffa0 rank 1 nranks 3 cudaDev 1 busId 23000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1802503 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1802544 milliseconds before timing out.

[00:13:40] WARNING  Sending process 2247376 closing signal SIGTERM  api.py:698
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808417 milliseconds before timing out.
privsec0:2247376:2248280 [0] NCCL INFO [Service thread] Connection closed by localRank 0
privsec0:2247376:2248253 [0] NCCL INFO comm 0x45978320 rank 0 nranks 3 cudaDev 0 busId 1000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808417 milliseconds before timing out.

maxreciprocate commented 1 year ago

@javirandor Hm, it may be hanging on the first barrier (given SeqNum=1) here:
https://github.com/CarperAI/trlx/blob/9bc08369ca9ec83342c4d7755205dab1a7723006/trlx/trainer/accelerate_base_trainer.py#L65-L66

Try commenting out those lines and give it another attempt. Also, have you tried running an unmodified existing example on your setup, or does that also fail with the same error?
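As an analogy only (Python threads, not NCCL): a collective such as a barrier or all-reduce completes only once every participating rank has entered it. SeqNum=1 in the log suggests the very first collective never completed, which is what you would see if one rank never reached the barrier mentioned above. The sketch uses `threading.Barrier` with a timeout to stand in for the NCCL watchdog; `run_ranks` is a hypothetical helper, not part of trlx or torch.

```python
import threading


def run_ranks(num_ranks: int, arriving_ranks: list[int], timeout_s: float = 0.5) -> dict[int, str]:
    """Hypothetical simulation: only `arriving_ranks` ever reach the barrier.

    The barrier expects `num_ranks` parties. If any rank never arrives, the
    waiting ranks time out and the barrier breaks, loosely mirroring the
    NCCL watchdog aborting a collective that never completes.
    """
    barrier = threading.Barrier(num_ranks, timeout=timeout_s)
    outcomes: dict[int, str] = {}
    lock = threading.Lock()

    def worker(rank: int) -> None:
        try:
            barrier.wait()
            result = "passed"
        except threading.BrokenBarrierError:
            result = "timed out"
        with lock:
            outcomes[rank] = result

    threads = [threading.Thread(target=worker, args=(r,)) for r in arriving_ranks]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcomes


if __name__ == "__main__":
    print(run_ranks(3, [0, 1, 2]))               # every rank reaches the barrier
    print(run_ranks(3, [1, 2], timeout_s=0.2))   # rank 0 never arrives
```

Raising the timeout in such a situation only makes the hang last longer; the useful question is why one rank never reaches the collective.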