NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

How does NCCL utilize shared memory when dynamic tensor shapes vary across training iterations? #1399

Open szhengac opened 3 months ago

szhengac commented 3 months ago

Hi,

I recently came across an issue when using context parallelism to split long sequences with NeMo and Transformer Engine. Context parallelism splits the sequence length across GPUs and uses p2p communication to implement a ring algorithm that accumulates the attention scores. It creates a number of p2p PyTorch communication groups.
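
For context, here is a minimal sketch (hypothetical, not the actual NeMo/Transformer Engine code) of how such per-neighbor p2p groups might get created with torch.distributed; the rank layout and function name are made up for illustration:

```python
# Hypothetical sketch: one small NCCL process group per adjacent pair of
# context-parallel ranks. Each group lazily creates its own NCCL
# communicator (with its own internal buffers) on first use.
import torch.distributed as dist

def build_cp_ring_groups(cp_ranks):
    pairs = set()
    n = len(cp_ranks)
    for i in range(n):
        pairs.add(tuple(sorted((cp_ranks[i], cp_ranks[(i + 1) % n]))))
    # new_group() is collective: every rank must call it, in the same order.
    return {pair: dist.new_group(ranks=list(pair), backend="nccl")
            for pair in sorted(pairs)}
```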

  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 1459, in forward_backward_pipelining_without_interleaving
    config.grad_sync_func(model.parameters())
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_base_model.py", line 682, in reduce_overlap_gradients
    self._optimizer.try_grad_sync(params)
  File "/opt/NeMo/nemo/core/optim/distributed_adam.py", line 485, in try_grad_sync
    self._try_start_bucket_grad_sync(params=params)
  File "/opt/apex/apex/contrib/optimizers/distributed_fused_adam.py", line 1821, in _try_start_bucket_grad_sync
    self._start_bucket_grad_sync(filled_buckets)
  File "/opt/apex/apex/contrib/optimizers/distributed_fused_adam.py", line 1876, in _start_bucket_grad_sync
    with _coalescing_manager(group, self.device, async_ops=True) as cm:
  File "/usr/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1936, in _coalescing_manager
    work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
torch.distributed.DistBackendError: NCCL error in: /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2006, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.20.5
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Failed to CUDA host alloc 2147483648 bytes

I checked line 2006 of ProcessGroupNCCL.cpp in the container. The error happens when PyTorch tries to create a new communicator. With context parallelism and dynamic sequence lengths, the p2p collectives operate on different tensor sizes across iterations, and I am not sure whether the NCCL error above is related. There was plenty of GPU memory available when the error occurred. Can you please explain how NCCL utilizes shared memory (/dev/shm)?

kiskra-nvidia commented 3 months ago

As the error message indicates, please rerun with NCCL_DEBUG=INFO to get additional details. There are countless reasons why creating a communicator may fail, and without the additional info it's not productive to speculate. In particular, it's not clear why you are asking about /dev/shm, given that it's not mentioned anywhere in the included output. In general, though, shared memory is one of the transport layers NCCL uses for communication within a node, especially when direct point-to-point communication between the GPUs is not available.
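
For completeness, one way to apply this (a sketch, assuming the variables are set before the first NCCL call; exporting them in the job script works just as well):

```python
# Sketch: NCCL reads these environment variables when the first communicator
# is created, so set them before init_process_group / the first collective.
import os
os.environ["NCCL_DEBUG"] = "INFO"
os.environ.setdefault("NCCL_DEBUG_SUBSYS", "INIT,ENV")  # optional: focus on init/env messages

import torch.distributed as dist
dist.init_process_group(backend="nccl")  # launched via torchrun/srun as usual
```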

szhengac commented 3 months ago

nemo_sft_15044.err.log nemo_sft_15044.out.log

@kiskra-nvidia Thanks for your response. I have attached the logs for this training run. The reason I asked about /dev/shm is that I noticed a lot of shared memory being used when there is a large number of p2p NCCL torch.distributed groups. CPU memory usage also goes up a lot, so I want to understand how NCCL uses /dev/shm and CPU memory. I believe GPU p2p direct connection is enabled, as shown in the attached logs.
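
As an aside, here is a rough way I could quantify this (a sketch; where exactly to place the calls in the training script is up to the user): sample /dev/shm usage and the process's peak RSS before and after the groups are created.

```python
# Sketch: sample /dev/shm usage and this process's peak RSS to watch the
# growth as NCCL communicators are created.
import resource
import shutil

def report(tag):
    shm = shutil.disk_usage("/dev/shm")
    rss_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # KiB on Linux
    print(f"[{tag}] /dev/shm used: {shm.used / 2**20:.0f} MiB, "
          f"peak RSS: {rss_kib / 2**10:.0f} MiB")

report("before p2p groups")
# ... create the p2p groups / run a warm-up iteration here ...
report("after p2p groups")
```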

kiskra-nvidia commented 3 months ago

It looks like NCCL is running out of host memory because NCCL_WORK_FIFO_DEPTH is set to 4194304, which requires a 2 GB allocation -- per process. So that's probably what's eating up your memory... The default value of this option is 65536, I think...
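
For reference, the numbers line up with the failed allocation in the log if one assumes a 512-byte work-FIFO slot (the slot size here is an assumption, not something stated in this thread):

```python
# Back-of-the-envelope check; the 512-byte slot size is an assumption.
fifo_depth = 4_194_304          # NCCL_WORK_FIFO_DEPTH from the attached log
slot_bytes = 512                # assumed size of one work-FIFO slot
print(fifo_depth * slot_bytes)  # 2147483648 -- the exact "Failed to CUDA host alloc" size
```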

szhengac commented 3 months ago

I do not see NCCL_WORK_FIFO_DEPTH in the official NCCL documentation. Is it a hidden environment variable? Also, does it scale linearly with the number of communicators?

kiskra-nvidia commented 3 months ago

I don't know its full history, but it appears to be an internal variable. Not all NCCL variables are documented, because some of them are meant primarily for our own development and debugging purposes and we may not want to support them long-term. Of course, anybody searching through the source code will find them, so I wouldn't exactly call them hidden :wink:. In the case of NCCL_WORK_FIFO_DEPTH in particular, it's gone as of NCCL 2.22. And yes, it appears that the buffer is allocated for each communicator...
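
So, under the same assumption as above, the host-memory footprint would grow roughly as communicator count times the FIFO allocation (purely illustrative):

```python
# Illustrative scaling only: per-communicator FIFO allocation times communicator count.
fifo_gib_per_comm = 2  # the 2 GiB host allocation seen in the log
for num_comms in (1, 4, 8, 16):
    print(f"{num_comms:2d} communicators -> ~{num_comms * fifo_gib_per_comm} GiB pinned host memory")
```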

szhengac commented 3 months ago

Ok, thanks for the clarification. If it scales linearly with the number of communicators, that does explain something. I thought NCCL would have some sharing mechanism to reduce the memory usage.

kiskra-nvidia commented 3 months ago

In principle NCCL can share resources between communicators created using ncclCommSplit with splitShare set -- but it's something that has to be explicitly requested...