NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

Concurrent NCCL calls and bidirectional PCI-e bandwidth #204

bobzhuyb commented 5 years ago

This is a question about NCCL performance details. I found that multiple concurrent NCCL calls on the same GPUs cannot efficiently use the bidirectional PCI-e bandwidth. I am wondering whether and where I did something wrong.

My server has two V100 GPUs connected to the same PCI-e switch on the same NUMA node, each with PCI-e 3.0 x16. CUDA 10 and NCCL 2.4 are installed. p2pBandwidthLatencyTest shows that with P2P enabled, unidirectional BW is ~13GB/s and bidirectional BW is ~25GB/s. With P2P disabled, the numbers are 10GB/s and 10GB/s. In all of the NCCL tests below, P2P is enabled, as confirmed by the NCCL_DEBUG=INFO output.
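
(For reference, p2pBandwidthLatencyTest is the standard CUDA sample; a sketch of how such baseline numbers are obtained, assuming the sample tree is installed at /usr/local/cuda/samples:)

cd /usr/local/cuda/samples/1_Utilities/p2pBandwidthLatencyTest
make
./p2pBandwidthLatencyTest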

I started with https://github.com/NVIDIA/nccl-tests and ran a broadcast test with just these two GPUs. I got 12+GB/s, which is close to the p2pBandwidthLatencyTest unidirectional BW. So far so good.

However, when I ran two broadcast tests (in two terminals, started by hand at around the same time), one with root=0 and the other with root=1, their individual BW was halved, with the total equal to the unidirectional BW of a single broadcast (12GB/s). Note the last argument in the commands below.

mpirun -np 2 --allow-run-as-root -x NCCL_DEBUG=INFO ./build/broadcast_perf -b 8 -e 512M -f 2 -g 1 -c 0 -r 0
mpirun -np 2 --allow-run-as-root -x NCCL_DEBUG=INFO ./build/broadcast_perf -b 8 -e 512M -f 2 -g 1 -c 0 -r 1

I am confused. These two broadcasts use PCI-e bandwidth in different directions, so I expected them to get ~12GB/s each, but they only got 6GB/s. I then guessed that was because they were in different processes and MPS was not enabled. So I copied the example from the NCCL documentation (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/examples.html), changed it to ncclBroadcast, and crudely added a second communicator (with its own buffers, cudaStreams, etc.) to broadcast in the reverse direction. Now the two broadcasts are in the same process.

Unfortunately, I still got the same result -- two broadcasts take exactly twice the time of one broadcast, even though the communication is supposed to go in different directions. The computation is essentially zero for a broadcast. What is the problem here, and how can I run concurrent NCCL calls to utilize the spare PCI-e bandwidth?

Thank you.

Below is the code:

#include <stdio.h>
#include <stdlib.h>
#include "cuda_runtime.h"
#include "nccl.h"

#define CUDACHECK(cmd) do {                         \
  cudaError_t e = cmd;                              \
  if( e != cudaSuccess ) {                          \
    printf("Failed: Cuda error %s:%d '%s'\n",             \
        __FILE__,__LINE__,cudaGetErrorString(e));   \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

#define NCCLCHECK(cmd) do {                         \
  ncclResult_t r = cmd;                             \
  if (r!= ncclSuccess) {                            \
    printf("Failed, NCCL error %s:%d '%s'\n",             \
        __FILE__,__LINE__,ncclGetErrorString(r));   \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

int main(int argc, char* argv[])
{
  ncclComm_t comms[2];
  ncclComm_t comms2[2];

  //managing 2 devices
  int nDev = 2;
  int size = 1024*1024*1024;
  int devs[2] = { 0, 1};

  //allocating and initializing device buffers
  float** sendbuff = (float**)malloc(nDev * sizeof(float*));
  float** recvbuff = (float**)malloc(nDev * sizeof(float*));
  cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);

  //allocating and initializing device buffers
  float** sendbuff2 = (float**)malloc(nDev * sizeof(float*));
  float** recvbuff2 = (float**)malloc(nDev * sizeof(float*));
  cudaStream_t* s2 = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);

  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaMalloc(sendbuff + i, size * sizeof(float)));
    CUDACHECK(cudaMalloc(recvbuff + i, size * sizeof(float)));
    CUDACHECK(cudaMemset(sendbuff[i], 1, size * sizeof(float)));
    CUDACHECK(cudaMemset(recvbuff[i], 0, size * sizeof(float)));
    CUDACHECK(cudaStreamCreate(s+i));
  }

  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaMalloc(sendbuff2 + i, size * sizeof(float)));
    CUDACHECK(cudaMalloc(recvbuff2 + i, size * sizeof(float)));
    CUDACHECK(cudaMemset(sendbuff2[i], 1, size * sizeof(float)));
    CUDACHECK(cudaMemset(recvbuff2[i], 0, size * sizeof(float)));
    CUDACHECK(cudaStreamCreate(s2+i));
  }

  //initializing NCCL
  NCCLCHECK(ncclCommInitAll(comms, nDev, devs));
  NCCLCHECK(ncclCommInitAll(comms2, nDev, devs));

   //calling NCCL communication API. Group API is required when using
   //multiple devices per thread
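   //comms[] broadcasts from root 0 while comms2[] broadcasts from root 1,
   //so the two sets of calls move data across PCI-e in opposite directions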
  NCCLCHECK(ncclGroupStart());
  for (int j = 0; j < 100; ++j) {
    for (int i = 0; i < nDev; ++i) {
      NCCLCHECK(ncclBroadcast((const void*)sendbuff[i], (void*)recvbuff[i], size, ncclFloat, 0,
          comms[i], s[i]));
      NCCLCHECK(ncclBroadcast((const void*)sendbuff2[i], (void*)recvbuff2[i], size, ncclFloat, 1,
          comms2[i], s2[i]));
    }
  }
  NCCLCHECK(ncclGroupEnd());

  //synchronizing on CUDA streams to wait for completion of NCCL operation
  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaStreamSynchronize(s[i]));
  }

  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaStreamSynchronize(s2[i]));
  }

  //free device buffers
  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaFree(sendbuff[i]));
    CUDACHECK(cudaFree(recvbuff[i]));
  }

  //free device buffers
  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaFree(sendbuff2[i]));
    CUDACHECK(cudaFree(recvbuff2[i]));
  }

  //finalizing NCCL
  for(int i = 0; i < nDev; ++i)
      ncclCommDestroy(comms[i]);

  for(int i = 0; i < nDev; ++i)
      ncclCommDestroy(comms2[i]);

  printf("Success \n");
  return 0;
}

sjeaugey commented 5 years ago

I would think that even with MPS enabled, the GPU is time-shared, meaning it is used alternately by one process and then the other, and the two kernels from the two processes would never run concurrently.

sjeaugey commented 5 years ago

Actually, I take that back; MPS should allow concurrent execution, and it would seem that it isn't happening.

That said, be aware that using NCCL + MPS is not supported as it might lead to a hang.

bobzhuyb commented 5 years ago

Thanks for the note about MPS. Maybe I didn't make it clear -- MPS was not enabled in any of my tests. Even within the same process, two NCCL broadcasts cannot utilize the bidirectional bandwidth. It seems that they are still time-shared.

My question is, can we avoid this and make multiple NCCL calls really concurrent (in the same process)?

sjeaugey commented 5 years ago

Oh, indeed, I misread. So maybe you'd want to create the streams with cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) so that the streams are non-blocking and can run concurrently.

Note again that this usage is not supported: even if the streams are asynchronous, they are not guaranteed to make concurrent progress and not hang.
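
A minimal sketch of that change, applied to the stream-creation loops in the test program above (same s and s2 arrays; an illustration of the suggestion, not a supported configuration):

  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    //non-blocking streams do not implicitly synchronize with the default stream,
    //so work submitted to s[i] and s2[i] may overlap
    CUDACHECK(cudaStreamCreateWithFlags(s + i, cudaStreamNonBlocking));
    CUDACHECK(cudaStreamCreateWithFlags(s2 + i, cudaStreamNonBlocking));
  }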

bobzhuyb commented 5 years ago

It does not work... it still takes twice the time. It's not just that "they are not guaranteed to make concurrent progress" -- they never seem to make concurrent progress.

BTW, the two broadcasts are using different streams. See s[i] and s2[i] below. They shouldn't block each other, right?

      NCCLCHECK(ncclBroadcast((const void*)sendbuff[i], (void*)recvbuff[i], size, ncclFloat, 0,
          comms[i], s[i]));
      NCCLCHECK(ncclBroadcast((const void*)sendbuff2[i], (void*)recvbuff2[i], size, ncclFloat, 1,
          comms2[i], s2[i]));

bobzhuyb commented 5 years ago

BTW, here is how I built the code and ran (timed) it:

nvcc test.cc -o test -I/usr/local/cuda/include/ -I/usr/local/nccl/include/ -L/usr/local/cuda/lib64 -L/usr/local/nccl/lib -lcudart -lrt -lnccl
time ./test

You can easily test the time difference with and without the second broadcast.

eric-haibin-lin commented 4 years ago

@sjeaugey is this still an issue on NCCL side?

sjeaugey commented 4 years ago

@bobzhuyb The reason for serialization is the cooperative group launch mode. If you set NCCL_LAUNCH_MODE=PARALLEL you should get twice the bandwidth. This is a limitation of cooperative group launch.
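
For example, with the test program built as above (a sketch; assumes the same ./test binary):

export NCCL_LAUNCH_MODE=PARALLEL
time ./test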

Alternatively, if you used multiple processes (one process per GPU, MPI-style parallelism), you would not have this issue, as that setup uses parallel launch mode by default and is not prone to deadlocks.