NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

Illegal memory access when executing grouped ncclSend/ncclRecvs on multiple nodes #1021

Open chenyu-jiang opened 1 year ago

chenyu-jiang commented 1 year ago

Hi,

I am trying to implement a special AllToAllv in which each rank has multiple data chunks to send to every other rank (each chunk can be a different size), using grouped ncclSend and ncclRecv calls. However, I am encountering the error "an illegal memory access was encountered" with some input sizes when running on multiple nodes. The following self-contained example code reproduces the error:

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>

#define CUDACHECK(cmd) do {                         \
  cudaError_t err = cmd;                            \
  if (err != cudaSuccess) {                         \
    printf("Failed: Cuda error %s:%d '%s'\n",       \
        __FILE__,__LINE__,cudaGetErrorString(err)); \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

#define NCCLCHECK(cmd) do {                         \
  ncclResult_t res = cmd;                           \
  if (res != ncclSuccess) {                         \
    printf("Failed, NCCL error %s:%d '%s'\n",       \
        __FILE__,__LINE__,ncclGetErrorString(res)); \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

int main(int argc, char* argv[]) {

  MPI_Init(&argc, &argv);
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  int world_size;
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);
  int local_rank;
  if(char* local_rank_char = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK")) {
    local_rank = atoi(local_rank_char);
  } else {
    std::cout << "Error: local rank not found" << std::endl;
    exit(1);
  }
  MPI_Barrier(MPI_COMM_WORLD);
  std::cout << "Rank: " << rank << ", local rank: " << local_rank << ", world_size: " << world_size << std::endl;
  CUDACHECK(cudaSetDevice(local_rank));
  MPI_Barrier(MPI_COMM_WORLD);
  ncclUniqueId nccl_id;
  if (rank == 0) {
    NCCLCHECK(ncclGetUniqueId(&nccl_id));
  }
  MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD);
  ncclComm_t nccl_comm;
  NCCLCHECK(ncclCommInitRank(&nccl_comm, world_size, nccl_id, rank));

  cudaStream_t comm_stream;
  CUDACHECK(cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking));

  // Each rank sends/receives 2 chunks of data to/from each other rank
  // maximum size of each chunk is 768 * 768 * 4 bytes
  // send_counts specifies the number of elements of each chunk
  // (assume 4 bytes per element)

  // 16 GPUs case (illegal mem access on rank 2)
  // std::vector<size_t> send_counts = {0, 0, 589824, 0, 768, 589824, 589824, 589824, 0, 33792, 0, 347136, 0, 63744, 85248, 0, 5376, 28416, 589824, 11520, 0, 475392, 1536, 3072, 138240, 0, 0, 0, 2304, 0, 36864, 37632};

  // another 8 GPUs case (illegal mem access on rank 1)
  std::vector<size_t> send_counts = {294912, 3072, 3072, 294912, 294912, 294912, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072, 3072};
  std::vector<size_t> recv_counts;
  for (int i = 0; i < world_size; i++) {
    recv_counts.push_back(send_counts[rank * 2]);
    recv_counts.push_back(send_counts[rank * 2 + 1]);
  }

  int n_chunks_per_device = 2;
  int dtype_size_in_bytes = 4;

  void* send_buf;
  void* recv_buf;
  CUDACHECK(cudaMalloc(&send_buf, world_size * n_chunks_per_device * 768 * 768 * dtype_size_in_bytes));
  CUDACHECK(cudaMalloc(&recv_buf, world_size * n_chunks_per_device * 768 * 768 * dtype_size_in_bytes));
  CUDACHECK(cudaMemset(send_buf, 0, world_size * n_chunks_per_device * 768 * 768 * dtype_size_in_bytes));
  CUDACHECK(cudaMemset(recv_buf, 0, world_size * n_chunks_per_device * 768 * 768 * dtype_size_in_bytes));

  char* send_buffer = static_cast<char*>(send_buf);
  char* recv_buffer = static_cast<char*>(recv_buf);

  size_t per_chunk_bytes = 768 * 768 * dtype_size_in_bytes;

  NCCLCHECK(ncclGroupStart());
  for(int i=0; i< world_size; i++) {
    for (int j=0; j< n_chunks_per_device; j++) {
      // if (i != rank) {
        size_t send_nbytes = send_counts[i * n_chunks_per_device + j] * dtype_size_in_bytes;
        if (send_nbytes > 0) {
          NCCLCHECK(ncclSend(static_cast<void*>(send_buffer + i * per_chunk_bytes * n_chunks_per_device + j * per_chunk_bytes),
                              send_nbytes, ncclChar, i, (ncclComm_t)nccl_comm, (cudaStream_t)comm_stream));
        }
        size_t recv_nbytes = recv_counts[i * n_chunks_per_device + j] * dtype_size_in_bytes;
        if (recv_nbytes > 0) {
          NCCLCHECK(ncclRecv(static_cast<void*>(recv_buffer + i * per_chunk_bytes * n_chunks_per_device + j * per_chunk_bytes),
                              recv_nbytes, ncclChar, i, (ncclComm_t)nccl_comm, (cudaStream_t)comm_stream));
        }
      // }
    }
  }
  NCCLCHECK(ncclGroupEnd());

  cudaError_t err = cudaStreamSynchronize(comm_stream);
  if (err != cudaSuccess) {
    std::cout << "Error: " << cudaGetErrorString(err) << std::endl;
    exit(1);
  }
  std::cout << "Rank: " << rank << ", local rank: " << local_rank << " finished" << std::endl;

  NCCLCHECK(ncclCommDestroy(nccl_comm));
  MPI_Finalize();

  return 0;
}

I am running the code on two AWS p4de instances, with NCCL version 2.18.5+cuda11.0. The compiled executable is launched through MPI on 4 GPUs in each node. Several observations:

  1. When running on a single node (8 GPUs), the error disappears.
  2. The error disappears when running on only 2 GPUs in each node (fails with 3 or more GPUs on each node).
  3. When skipping sends to self (i.e., uncommenting if (i != rank)), the error is gone.
  4. When changing any one of the 294912 values in send_counts (e.g., to 3072), the problem disappears.

After many attempts, I am still unable to identify the error. Any help would be much appreciated!

Below is the log from the failed rank (when setting NCCL_DEBUG=INFO), for your information. (Note: EFA is disabled on the instance as I try to isolate the cause of the error. The error still exists with EFA enabled.)


Rank: 1, local rank: 1, world_size: 8
ip-172-31-21-100:18571:18571 [1] NCCL INFO cudaDriverVersion 12000
ip-172-31-21-100:18571:18571 [1] NCCL INFO Bootstrap : Using ens32:172.31.21.100<0>
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v6 symbol.
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin symbol (v4 or v5).
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Using aws-ofi-nccl 1.6.0
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Configuring AWS-specific options
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Setting provider_filter to efa
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Setting NCCL_PROTO to "simple"
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/OFI Running on p4de.24xlarge platform, Setting NCCL_TOPO_FILE environment variable to /usr/local/cuda-11.8/efa/share/aws-ofi-nccl/xml/p4de-24xl-topo.xml

ip-172-31-21-100:18571:18571 [1] nccl_net_ofi_init:1472 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/IB : No device found.
ip-172-31-21-100:18571:18571 [1] NCCL INFO NET/Socket : Using [0]ens32:172.31.21.100<0>
ip-172-31-21-100:18571:18571 [1] NCCL INFO Using network Socket
ip-172-31-21-100:18571:18571 [1] NCCL INFO comm 0x55661a7ffe80 rank 1 nranks 8 cudaDev 1 nvmlDev 1 busId 101d0 commId 0xaf2ae6c5f9de423 - Init START
ip-172-31-21-100:18571:18571 [1] NCCL INFO NCCL_TOPO_FILE set by environment to /usr/local/cuda-11.8/efa/share/aws-ofi-nccl/xml/p4de-24xl-topo.xml
ip-172-31-21-100:18571:18571 [1] NCCL INFO Setting affinity for GPU 1 to ff,ffff0000,00ffffff
ip-172-31-21-100:18571:18571 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0
ip-172-31-21-100:18571:18571 [1] NCCL INFO P2P Chunksize set to 131072
ip-172-31-21-100:18571:18571 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/IPC/read
ip-172-31-21-100:18571:18571 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/IPC/read
ip-172-31-21-100:18571:18571 [1] NCCL INFO Connected all rings
ip-172-31-21-100:18571:18571 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/IPC/read
ip-172-31-21-100:18571:18571 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/IPC/read
ip-172-31-21-100:18571:18571 [1] NCCL INFO Connected all trees
ip-172-31-21-100:18571:18571 [1] NCCL INFO NCCL_PROTO set by environment to simple
ip-172-31-21-100:18571:18571 [1] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
ip-172-31-21-100:18571:18571 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
ip-172-31-21-100:18571:18571 [1] NCCL INFO comm 0x55661a7ffe80 rank 1 nranks 8 cudaDev 1 nvmlDev 1 busId 101d0 commId 0xaf2ae6c5f9de423 - Init COMPLETE
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 2[2] via P2P/IPC/read
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 2[2] via P2P/IPC/read
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 7[3] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 7[3] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 3[3] via P2P/IPC/read
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 3[3] via P2P/IPC/read
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 6[2] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 6[2] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 4[0] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 4[0] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 5[1] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 5[1] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 5[1] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 5[1] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 4[0] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18596 [1] NCCL INFO NET/Socket: Using 2 threads and 8 sockets per thread
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 4[0] -> 1[1] [receive] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 6[2] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 6[2] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 7[3] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 7[3] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 6[2] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 7[3] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 7[3] [send] via NET/Socket/0/Shared
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 00/1 : 1[1] -> 0[0] via P2P/IPC/read
ip-172-31-21-100:18571:18606 [1] NCCL INFO Channel 01/1 : 1[1] -> 0[0] via P2P/IPC/read
Error: an illegal memory access was encountered
sjeaugey commented 1 year ago

I think you wanted to use i * 2 and i * 2 + 1 here:

  for (int i = 0; i < world_size; i++) {
    recv_counts.push_back(send_counts[rank * 2]);
    recv_counts.push_back(send_counts[rank * 2 + 1]);
  }

Also you may want to check return codes of all CUDA and NCCL calls (using e.g. NCCLCHECK or CUDACHECK macros). Not doing so can make things hard to debug.

chenyu-jiang commented 1 year ago

Hi @sjeaugey Thanks for the prompt response!

I feel it is correct to use rank * 2 and rank * 2 + 1 here: The goal is to make every rank send send_counts[0]*4, send_counts[1]*4 bytes of data to rank 0, send_counts[2]*4, send_counts[3]*4 bytes of data to rank 1 and so on. So rank r will receive send_counts[2*r]*4, send_counts[2*r + 1]*4 bytes of data from every rank.

As suggested, I have added NCCLCHECK and CUDACHECK around all CUDA and NCCL calls (as reflected in the updated example code). But none of them catches any error and the output is the same.

sjeaugey commented 1 year ago

Ah. My bad, indeed rank is correct. Sorry I misinterpreted how you were using this array.

Thanks for the confirmation.

sjeaugey commented 1 year ago

I can repro and I do see the local copy copying with NULL as the recvBuff. I should be able to figure this out quickly.

sjeaugey commented 1 year ago

Ok so indeed it's a bug with how we aggregate operations; in particular for self-sendrecv, we need to ensure they're next to each other.

Setting NCCL_NCHANNELS_PER_NET_PEER=1 should work as a workaround until this is resolved.

chenyu-jiang commented 1 year ago

Thanks! Will setting NCCL_NCHANNELS_PER_NET_PEER=1 affect communication speed? I am currently working around the problem by skipping self-sendrecvs and manually adding cudaMemcpyAsync calls. Of course it would be easier if self-sendrecvs were correctly handled by NCCL.

arttianezhu commented 1 year ago

Hi @sjeaugey, can you provide some more context on the local copy copying with NULL as the recvBuff? Where can we find this in the NCCL code?

sjeaugey commented 1 year ago

It's pretty complex, and I'm trying to find a way to fix this without changing too much code.

It's tied to how we pack operations into ncclWorkElem structures. Self send/recv are supposed to be next to each other for the code to work, but in that precise case it isn't.

sjeaugey commented 1 year ago

Here is a patch which should fix the bug.

diff --git a/src/enqueue.cc b/src/enqueue.cc
index dbb9865bc..71bf45a60 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -633,7 +633,6 @@ static ncclResult_t scheduleP2pTasksToPlan(
     for (int i=0; i < tasks->p2pOrderSteps; i++) {
       int sendPeer = sendOrder[i];
       int recvPeer = recvOrder[i];
-      if ((i % (NCCL_MAX_WORK_ELEMENTS_P2P/2)) == 0) fuseOk = false;
       struct ncclTaskP2p* send = sendPeer != -1 ? ncclIntruQueueHead(&peers[sendPeer].sendQueue) : NULL;
       struct ncclTaskP2p* recv = recvPeer != -1 ? ncclIntruQueueHead(&peers[recvPeer].recvQueue) : NULL;
       if (sendPeer == comm->rank) {
@@ -669,6 +668,7 @@ static ncclResult_t scheduleP2pTasksToPlan(
         if (send) sendBytes -= send->chunk*sendChunkBytesMax;

         do {
+          if ((i % (NCCL_MAX_WORK_ELEMENTS_P2P/2)) == 0) fuseOk = false;
           ssize_t recvChunkBytes = std::min(recvBytes, recvChunkBytesMax); // -1 preserved
           ssize_t sendChunkBytes = std::min(sendBytes, sendChunkBytesMax);
           if (recvChunkBytes != 0) {

Please confirm this is fixing the issue.

The goal of fuseOk is to avoid fusing operations from different nodes and to make sure self-communication is at the beginning of the workElem. But fuseOk was only set correctly for the first chunk; if we split the operation across multiple channels, the second channel may experience unwanted fusion, potentially causing hangs and breaking self-communication.

chenyu-jiang commented 1 year ago

Thanks! The error is indeed gone after applying the patch.