NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

Multiple ncclRecv calls within ncclGroupStart/ncclGroupEnd seem to produce incorrect results #501

Open · pritamdamania87 opened this issue 3 years ago

pritamdamania87 commented 3 years ago

I was debugging the following PyTorch issue regarding NCCL send/recv: https://github.com/pytorch/pytorch/issues/50092. To isolate whether the problem is in NCCL or in the PyTorch implementation, I tried to reproduce it in NCCL itself. My code, which uses NCCL send/recv, is below.

The interesting part is that when I remove the second ncclRecv and its associated verification, everything works fine, but two ncclRecv calls in the same group don't seem to work. I'm not sure if there is a bug in my code causing this. The error I see on rank 1 is the following (it looks like recvbuff is all zeros):

nccl_p2p.cpp:66: int main(int, char **): Assertion `recvbuff[i] == i' failed.

#include <nccl.h>
#include <cuda_runtime.h>
#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

int main(int argc, char* argv[]) {
  int rank = std::stoi(argv[1]);
  std::cerr << "Rank: " << rank << std::endl;
  std::string ncclid_filename = "/tmp/ncclid";
  ncclUniqueId ncclID;
  if (rank == 0) {
    assert(ncclGetUniqueId(&ncclID) == ncclSuccess);
    std::ofstream ncclid_file;
    ncclid_file.open(ncclid_filename);
    ncclid_file.write(
        reinterpret_cast<const char*>(&ncclID), NCCL_UNIQUE_ID_BYTES);
    ncclid_file.close();
  } else {
    // Read the ID written by rank 0 (the repro assumes rank 0 is launched
    // first, so /tmp/ncclid already exists).
    std::ifstream ncclid_file(ncclid_filename);
    ncclid_file.read(reinterpret_cast<char*>(&ncclID), NCCL_UNIQUE_ID_BYTES);
  }

  ncclComm_t comm;
  cudaSetDevice(rank);
  assert(ncclCommInitRank(&comm, 2, ncclID, rank) == ncclSuccess);
  size_t size = 100000;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* cudabuff;
  int* cudabuff1;
  cudaMalloc(&cudabuff, size * sizeof(int));
  cudaMalloc(&cudabuff1, size * sizeof(int));
  assert(ncclGroupStart() == ncclSuccess);
  if (rank == 0) {
    // Send
    std::vector<int> sendbuff(size);
    for (int i = 0; i < size; i++) {
      sendbuff[i] = i;
    }

    cudaMemcpy(
        cudabuff, sendbuff.data(), size * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(
        cudabuff1, sendbuff.data(), size * sizeof(int), cudaMemcpyHostToDevice);
    assert(ncclSend(cudabuff, size, ncclInt32, 1, comm, stream) == ncclSuccess);
    assert(
        ncclSend(cudabuff1, size, ncclInt32, 1, comm, stream) == ncclSuccess);
  } else {
    // Recv.
    assert(ncclRecv(cudabuff, size, ncclInt32, 0, comm, stream) == ncclSuccess);
    assert(
        ncclRecv(cudabuff1, size, ncclInt32, 0, comm, stream) == ncclSuccess);
  }
  assert(ncclGroupEnd() == ncclSuccess);
  // Make sure the grouped send/recv operations have completed before
  // validating the received data.
  cudaStreamSynchronize(stream);

  if (rank == 1) {
    auto recvbuff = new int[size];
    cudaMemcpy(recvbuff, cudabuff, size * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < size; i++) {
      assert(recvbuff[i] == i);
    }
    cudaMemcpy(recvbuff, cudabuff1, size * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < size; i++) {
      assert(recvbuff[i] == i);
    }
  }
  cudaDeviceSynchronize();
  ncclCommDestroy(comm);
  return 0;
}
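
One caveat with the repro as written: wrapping every NCCL/CUDA call in assert means the checks are compiled out entirely under -DNDEBUG, and a failing call aborts without printing the library's error string. Below is a minimal checking-macro sketch; NCCLCHECK and CUDACHECK are local helper names invented here, not part of the NCCL API.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <nccl.h>

// Local helper macros (hypothetical names): fail loudly with the library's
// error string instead of relying on assert, which vanishes under -DNDEBUG.
#define NCCLCHECK(cmd) do {                                    \
  ncclResult_t r = (cmd);                                      \
  if (r != ncclSuccess) {                                      \
    std::fprintf(stderr, "NCCL error %s:%d: %s\n",             \
                 __FILE__, __LINE__, ncclGetErrorString(r));   \
    std::exit(EXIT_FAILURE);                                   \
  }                                                            \
} while (0)

#define CUDACHECK(cmd) do {                                    \
  cudaError_t e = (cmd);                                       \
  if (e != cudaSuccess) {                                      \
    std::fprintf(stderr, "CUDA error %s:%d: %s\n",             \
                 __FILE__, __LINE__, cudaGetErrorString(e));   \
    std::exit(EXIT_FAILURE);                                   \
  }                                                            \
} while (0)

// Usage: NCCLCHECK(ncclSend(cudabuff, size, ncclInt32, 1, comm, stream));
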
sjeaugey commented 3 years ago

What NCCL version are you using? Sending to the same peer multiple times was not supported in NCCL 2.7. It should work with NCCL 2.8 onwards.
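
For anyone stuck on 2.7, one possible workaround (a sketch, assuming the two-rank setup and the rank/comm/stream/buffer variables from the repro above) is to issue each transfer in its own group, so that no single group contains two sends or two receives for the same peer:

// Sketch of a 2.7-era workaround: split the two transfers into separate
// groups so each group has at most one send/recv per peer. Assumes rank,
// comm, stream, cudabuff, cudabuff1, and size are set up as in the repro.
int* bufs[2] = {cudabuff, cudabuff1};
for (int xfer = 0; xfer < 2; xfer++) {
  assert(ncclGroupStart() == ncclSuccess);
  if (rank == 0) {
    assert(ncclSend(bufs[xfer], size, ncclInt32, 1, comm, stream) == ncclSuccess);
  } else {
    assert(ncclRecv(bufs[xfer], size, ncclInt32, 0, comm, stream) == ncclSuccess);
  }
  assert(ncclGroupEnd() == ncclSuccess);
}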

pritamdamania87 commented 3 years ago

We're actually using 2.7 (NCCL version 2.7.3); we'll try out 2.8. Btw, what is the latest stable release for NCCL? Is the latest release on this page always stable: https://github.com/NVIDIA/nccl/releases (currently 2.9.6-1)?
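
As a side note, a quick way to confirm which NCCL library a binary actually picks up at runtime is ncclGetVersion. A minimal sketch (note that the integer encoding of the version code changed between release lines):

#include <cstdio>
#include <nccl.h>

int main() {
  // ncclGetVersion reports the version of the NCCL library loaded at
  // runtime. The integer encoding differs across releases, e.g. 2703
  // for 2.7.3 but 20906 for 2.9.6.
  int version = 0;
  if (ncclGetVersion(&version) == ncclSuccess) {
    std::printf("NCCL runtime version code: %d\n", version);
  }
  return 0;
}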

jxmsML commented 3 years ago

This issue does indeed seem to go away on NCCL 2.8.3, at least.