microsoft / mscclpp

MSCCL++: A GPU-driven communication stack for scalable AI applications
MIT License

[Bug] Program hangs at proxy channel `wait()` #285

Closed liangyuRain closed 4 months ago

liangyuRain commented 7 months ago

Hi, we have encountered a problem with the proxy channel's signal()/wait(). We have A100 GPUs connected via NVSwitch, and the proxy channels use CudaIpc connections. We do the following to synchronize across all GPUs before a collective communication:

extern "C" __global__ void __launch_bounds__(1024) allgather(...) {
    ...
    if (threadId < nchannels) {
        proxy_channels[threadId].signal();
        proxy_channels[threadId].wait();
        ...
    }
}

However, the program hangs at wait(). After some debugging, we discovered that the cudaMemcpyAsync in CudaIpcConnection::updateAndSync never seems to complete. We tried adding MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_)); immediately after the cudaMemcpyAsync, and the program then hangs at cudaStreamSynchronize. Changing cudaMemcpyAsync to cudaMemcpy also hangs.
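
To illustrate why a copy that never completes shows up as a hang in wait(), here is a minimal host-only C++ sketch of the signal/wait handshake. It is a simplified model under stated assumptions, not MSCCL++ code: ModelChannel, proxyLoop, runHandshake, and the copyCompletes flag are hypothetical names, ordinary threads stand in for the kernel and the proxy, and a timeout is added to wait() only so that the model terminates.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <thread>

// Hypothetical host-side model of one end of a proxy channel (not MSCCL++ code).
struct ModelChannel {
    std::atomic<uint64_t> requests{0};             // signal() requests queued for the proxy
    std::atomic<uint64_t> inbound{0};              // semaphore written by the peer's proxy
    std::atomic<uint64_t>* peerInbound = nullptr;  // where our proxy delivers signals
    uint64_t expected = 0;                         // signals this side has waited for so far

    // "Kernel" side: enqueue a request; delivery is the proxy's job.
    void signal() { requests.fetch_add(1, std::memory_order_release); }

    // "Kernel" side: spin until the peer's signal arrives.
    // The timeout exists only in this model so that it can terminate.
    bool wait(std::chrono::milliseconds timeout) {
        ++expected;
        const auto deadline = std::chrono::steady_clock::now() + timeout;
        while (inbound.load(std::memory_order_acquire) < expected) {
            if (std::chrono::steady_clock::now() > deadline) return false;
            std::this_thread::yield();
        }
        return true;
    }
};

// Proxy loop: forwards each queued signal to the peer's semaphore, but only
// when the modeled copy is allowed to complete.
void proxyLoop(ModelChannel& ch, const std::atomic<bool>& stop, bool copyCompletes) {
    uint64_t delivered = 0;
    while (!stop.load(std::memory_order_acquire)) {
        if (copyCompletes && delivered < ch.requests.load(std::memory_order_acquire)) {
            ch.peerInbound->store(++delivered, std::memory_order_release);
        }
        std::this_thread::yield();
    }
}

// Runs signal() then wait() on two peers; returns true if both waits returned.
bool runHandshake(bool copyCompletes) {
    ModelChannel a, b;
    a.peerInbound = &b.inbound;
    b.peerInbound = &a.inbound;
    std::atomic<bool> stop{false};
    bool okA = false, okB = false;
    const std::chrono::milliseconds timeout(1000);

    std::thread proxyA([&] { proxyLoop(a, stop, copyCompletes); });
    std::thread proxyB([&] { proxyLoop(b, stop, copyCompletes); });
    std::thread kernelA([&] { a.signal(); okA = a.wait(timeout); });
    std::thread kernelB([&] { b.signal(); okB = b.wait(timeout); });

    kernelA.join();
    kernelB.join();
    stop.store(true, std::memory_order_release);
    proxyA.join();
    proxyB.join();
    return okA && okB;
}
```

In this model, runHandshake(true) returns true, while runHandshake(false) returns false once the timeout expires. Without the timeout, the second case would spin forever, which corresponds to the hang described above.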

We also found that if we call cudaDeviceSynchronize immediately after this allgather kernel, the program does not hang. It hangs only when another kernel is launched on the stream after allgather without an intervening cudaDeviceSynchronize. We wonder if you have any thoughts on this issue. Thanks!

Code to set up the proxy channels:

int main(int argc, char* argv[]) {
    // Initialize the MPI environment
    MPI_Init(&argc, &argv);

    // Get the number of processes
    int nranks;
    MPI_Comm_size(MPI_COMM_WORLD, &nranks);

    // Get the rank of the process
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    CUDA_CHECK(cudaSetDevice(rank));

    // Print off a hello world message
    std::cout << "Hello world from rank " << rank << " out of " << nranks << " ranks" << std::endl;

    // Initialize Communicator
    auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
    mscclpp::UniqueId uniqueId;
    if (rank == 0) uniqueId = bootstrap->createUniqueId();
    MPI_Bcast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
    bootstrap->initialize(uniqueId);
    auto comm = std::make_shared<mscclpp::Communicator>(bootstrap);

    // Initialize Connections
    std::vector<std::shared_ptr<mscclpp::Connection>> connections;
    std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
    for (int r = 0; r < nranks; ++r) {
        if (r == rank) continue;
        mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
        connectionFutures.push_back(comm->connectOnSetup(r, 0, transport));
    }
    comm->setup();
    std::transform(
        connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
        [](const mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });

    MPI_Barrier(MPI_COMM_WORLD);

    ...
}

void setupProxyChannels(std::shared_ptr<mscclpp::ProxyService> service,
                        std::shared_ptr<mscclpp::Communicator> comm,
                        std::vector<std::shared_ptr<mscclpp::Connection>> connections,
                        mscclpp::DeviceHandle<mscclpp::SimpleProxyChannel>** proxyChannelHandlesCuda,
                        Element* input, Element* output, int input_size, int output_size) {
    const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
    mscclpp::RegisteredMemory inputBuffRegMem = comm->registerMemory(input, input_size * sizeof(Element), allTransports);
    mscclpp::RegisteredMemory outputBuffRegMem;
    if (input != output) outputBuffRegMem = comm->registerMemory(output, output_size * sizeof(Element), allTransports);

    std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
    mscclpp::RegisteredMemory& localRegMemory = (input != output) ? outputBuffRegMem : inputBuffRegMem;

    for (int r = 0; r < nranks; ++r) {
        if (r == rank) continue;
        comm->sendMemoryOnSetup(localRegMemory, r, 0);
        auto remoteMemory = comm->recvMemoryOnSetup(r, 0);
        remoteRegMemories.push_back(remoteMemory);
    }
    comm->setup();
    for (size_t i = 0; i < connections.size(); ++i) {
        proxyChannelHandles.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
            service->proxyChannel(service->buildAndAddSemaphore(*comm, connections[i])),
            service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBuffRegMem)
        )));
    }
    comm->setup();

    assert(connections.size() == nranks - 1);
    CUDA_CHECK(cudaMalloc(proxyChannelHandlesCuda, (nranks - 1) * sizeof(mscclpp::DeviceHandle<mscclpp::SimpleProxyChannel>)));
    CUDA_CHECK(cudaMemcpy(*proxyChannelHandlesCuda, &proxyChannelHandles[proxyChannelHandles.size() - (nranks - 1)],
                          (nranks - 1) * sizeof(mscclpp::DeviceHandle<mscclpp::SimpleProxyChannel>), cudaMemcpyHostToDevice));
}

chhwang commented 6 months ago

Hi @liangyuRain, if cudaMemcpyAsync hangs, it is highly likely that your application is preventing other streams from running concurrently. Are you using the default stream in your application?