NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Some ranks failed to wait for the barrier. #2499

Open wujingyue opened 6 days ago

wujingyue commented 6 days ago

I managed to narrow down a non-deterministic barrier hang into the following reliable reproducer.

TEST_P(CommunicationTest, Barrier) {
  const auto device_id = communicator_->deviceId();
  for (int i = 0; i < 4; i++) {
    sleep(device_id);
    LOG(INFO) << "[Rank " << device_id << "] enter barrier " << i;
    communicator_->barrier();
    LOG(INFO) << "[Rank " << device_id << "] exit barrier " << i;
  }
}

To run this, check out the wjy/barrier branch and run _bn && python setup.py develop --debug && TORCH_CPP_LOG_LEVEL=INFO mpirun -np 3 bin/test_multidevice --gtest_filter=CommunicationTest.Barrier/NCCL. Note that two GPUs weren't enough to reproduce this, hence -np 3.

I printed out the time that each rank enters and exits the barrier:

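timestamp | rank 0  | rank 1  | rank 2
----------|---------|---------|--------
:27       | enter 0 |         |
:28       |         | enter 0 |
:29       |         |         | enter 0
:30       | exit 0  | exit 0  | exit 0
:30       | enter 1 |         |
:31       |         | enter 1 |
:31       |         | exit 1  |
:32       |         | enter 2 | enter 1
:32       | exit 1  | exit 2  | exit 1
:32       | enter 2 |         |
:33       |         | enter 3 |
:33       |         | exit 3  |
:34       |         |         | enter 2
:34       | exit 2  |         | exit 2
:34       | enter 3 |         |
:36       |         |         | enter 3
:36       |         |         | exit 3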

Rank 0 then hangs in barrier 3 and eventually times out.

Observations: rank 0 and rank 2 exit the same barrier at the same time, which is the expected behavior. However, rank 1 runs free and doesn't synchronize with the other ranks except for the first iteration.
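One way to make "rank 1 doesn't synchronize" precise: a barrier is honored only if no rank exits barrier i before every rank has entered barrier i. The sketch below checks that invariant on the table above, transcribed to whole seconds. barrier_violations is a hypothetical helper for illustration, not part of nvFuser or PyTorch.

```python
from collections import defaultdict


def barrier_violations(events, world_size):
    """Return {barrier_index: sorted ranks that exited too early}.

    `events` holds (time_s, rank, kind, barrier_index) tuples, with kind
    "enter" or "exit". A rank exits too early if it leaves barrier i before
    all `world_size` ranks have entered barrier i.
    """
    entered = defaultdict(set)  # barrier index -> ranks that entered it
    last_enter = {}  # barrier index -> latest enter time
    exits = defaultdict(list)  # barrier index -> [(time, rank), ...]
    for t, rank, kind, idx in events:
        if kind == "enter":
            entered[idx].add(rank)
            last_enter[idx] = max(last_enter.get(idx, t), t)
        elif kind == "exit":
            exits[idx].append((t, rank))
        else:
            raise ValueError(f"unknown event kind: {kind!r}")

    violations = {}
    for idx, idx_exits in exits.items():
        # If some rank never entered, no exit from this barrier is legitimate.
        if len(entered[idx]) < world_size:
            threshold = float("inf")
        else:
            threshold = last_enter[idx]
        early = sorted(rank for t, rank in idx_exits if t < threshold)
        if early:
            violations[idx] = early
    return violations


# The table above as (second, rank, kind, barrier index).
TABLE = [
    (27, 0, "enter", 0), (28, 1, "enter", 0), (29, 2, "enter", 0),
    (30, 0, "exit", 0), (30, 1, "exit", 0), (30, 2, "exit", 0),
    (30, 0, "enter", 1), (31, 1, "enter", 1), (31, 1, "exit", 1),
    (32, 2, "enter", 1), (32, 1, "enter", 2),
    (32, 0, "exit", 1), (32, 1, "exit", 2), (32, 2, "exit", 1),
    (32, 0, "enter", 2), (33, 1, "enter", 3), (33, 1, "exit", 3),
    (34, 2, "enter", 2), (34, 0, "exit", 2), (34, 2, "exit", 2),
    (34, 0, "enter", 3), (36, 2, "enter", 3), (36, 2, "exit", 3),
]

if __name__ == "__main__":
    print(barrier_violations(TABLE, world_size=3))
```

On this data the check flags rank 1 for barriers 1, 2 and 3 and nothing for barrier 0, matching the observation.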

FWIW, some warnings were printed at the beginning mentioning a potential hang. However, the guessed rank-to-GPU mapping seems to be correct, so I doubt that's the cause.

[I628 23:23:24.590086547 ProcessGroupNCCL.cpp:3905] [PG  Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.
[I628 23:23:25.784764891 ProcessGroupNCCL.cpp:3905] [PG  Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.
[I628 23:23:26.875546969 ProcessGroupNCCL.cpp:3905] [PG  Rank 2]  using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.
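The warnings suggest passing device_ids to barrier(). Below is a minimal sketch of doing that under mpirun, assuming one process per GPU on a single node and that Open MPI's OMPI_COMM_WORLD_LOCAL_RANK variable is set. local_device_index and barrier_on_local_device are hypothetical helper names. torch.cuda.set_device and torch.distributed.barrier(device_ids=...) are existing PyTorch calls, and the device_ids argument applies to the NCCL backend.

```python
import os


def local_device_index(env=os.environ, num_gpus=None):
    """Map this MPI process to a GPU index (assumes one process per GPU).

    Prefers Open MPI's OMPI_COMM_WORLD_LOCAL_RANK and falls back to the
    global OMPI_COMM_WORLD_RANK when the local rank is not set.
    """
    local_rank = env.get("OMPI_COMM_WORLD_LOCAL_RANK")
    if local_rank is None:
        local_rank = env["OMPI_COMM_WORLD_RANK"]
    local_rank = int(local_rank)
    return local_rank % num_gpus if num_gpus else local_rank


def barrier_on_local_device():
    """Bind the current device, then tell the NCCL barrier which GPU to use.

    Assumes dist.init_process_group(backend="nccl", ...) was already called.
    """
    # Imported lazily so the mapping helper above can be used without torch.
    import torch
    import torch.distributed as dist

    device = local_device_index(num_gpus=torch.cuda.device_count())
    torch.cuda.set_device(device)
    dist.barrier(device_ids=[device])
```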

FWIW, raw data for the table above:

[I628 22:58:27.038053515 test_multidevice_communications.cpp:384] [Rank 0] enter barrier
[I628 22:58:28.195701313 test_multidevice_communications.cpp:384] [Rank 1] enter barrier
[I628 22:58:29.193994645 test_multidevice_communications.cpp:384] [Rank 2] enter barrier
[I628 22:58:30.159562162 test_multidevice_communications.cpp:386] [Rank 1] exit barrier
[I628 22:58:30.159649661 test_multidevice_communications.cpp:386] [Rank 2] exit barrier
[I628 22:58:30.159801151 test_multidevice_communications.cpp:386] [Rank 0] exit barrier
[I628 22:58:30.159886257 test_multidevice_communications.cpp:384] [Rank 0] enter barrier
[I628 22:58:31.159694592 test_multidevice_communications.cpp:384] [Rank 1] enter barrier
[I628 22:58:31.159844356 test_multidevice_communications.cpp:386] [Rank 1] exit barrier
[I628 22:58:32.159799395 test_multidevice_communications.cpp:384] [Rank 2] enter barrier
[I628 22:58:32.159933381 test_multidevice_communications.cpp:384] [Rank 1] enter barrier
[I628 22:58:32.159961498 test_multidevice_communications.cpp:386] [Rank 2] exit barrier
[I628 22:58:32.160004101 test_multidevice_communications.cpp:386] [Rank 0] exit barrier
[I628 22:58:32.160092286 test_multidevice_communications.cpp:384] [Rank 0] enter barrier
[I628 22:58:32.160109484 test_multidevice_communications.cpp:386] [Rank 1] exit barrier
[I628 22:58:33.160227842 test_multidevice_communications.cpp:384] [Rank 1] enter barrier
[I628 22:58:33.160368688 test_multidevice_communications.cpp:386] [Rank 1] exit barrier
[I628 22:58:34.160084277 test_multidevice_communications.cpp:384] [Rank 2] enter barrier
[I628 22:58:34.160413208 test_multidevice_communications.cpp:386] [Rank 2] exit barrier
[I628 22:58:34.160519817 test_multidevice_communications.cpp:386] [Rank 0] exit barrier
[I628 22:58:34.160611678 test_multidevice_communications.cpp:384] [Rank 0] enter barrier
[I628 22:58:36.160531759 test_multidevice_communications.cpp:384] [Rank 2] enter barrier
[I628 22:58:36.160685263 test_multidevice_communications.cpp:386] [Rank 2] exit barrier
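These raw lines carry no barrier index, so building the table by hand requires counting each rank's enter and exit lines in order. Here is a small parser sketch for exactly the line format above. parse_barrier_log is a hypothetical name, and the regex assumes the glog-style prefix shown ([I<date> HH:MM:SS.frac file:line]).

```python
import re
from collections import Counter

# Matches e.g. "[I628 22:58:27.038053515 foo.cpp:384] [Rank 0] enter barrier".
LINE_RE = re.compile(
    r"\[I\d+ (\d+):(\d+):(\d+\.\d+) [^\]]*\] \[Rank (\d+)\] (enter|exit) barrier"
)


def parse_barrier_log(lines):
    """Turn raw log lines into (seconds, rank, kind, barrier_index) tuples.

    The barrier index is recovered by counting each rank's own "enter"
    (or "exit") lines in the order they appear. Non-matching lines are skipped.
    """
    counts = Counter()  # (rank, kind) -> occurrences seen so far
    events = []
    for line in lines:
        m = LINE_RE.search(line)
        if m is None:
            continue
        hh, mm, ss, rank, kind = m.groups()
        seconds = int(hh) * 3600 + int(mm) * 60 + float(ss)
        key = (int(rank), kind)
        events.append((seconds, int(rank), kind, counts[key]))
        counts[key] += 1
    return events
```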
wujingyue commented 6 days ago

cc @samnordmann and @cowanmeg

wujingyue commented 6 days ago

~I'll try next to reproduce this with raw process group calls, just to isolate the problem.~

I added a Python reproducer that uses the process group directly to the same branch. It doesn't hang, but the timing is just as weird as in my original report: rank 1 doesn't synchronize except for the first iteration.

$ cat test_barrier.py
import logging
import os
import time
import torch.distributed as dist

def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    assert world_size >= 3, "This test requires at least 3 GPUs."
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    for i in range(4):
        time.sleep(rank)
        logging.info(f"Rank {rank}: enter barrier {i}")
        dist.barrier()
        logging.info(f"Rank {rank}: exit barrier {i}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
$ TORCH_CPP_LOG_LEVEL=INFO TORCH_DISTRIBUTED_DEBUG=INFO mpirun -np 3 python test_barrier.py
...
2024-06-29 06:07:41,109 - Rank 0: enter barrier 0
2024-06-29 06:07:42,083 - Rank 1: enter barrier 0
2024-06-29 06:07:43,108 - Rank 2: enter barrier 0
2024-06-29 06:07:43,708 - Rank 2: exit barrier 0
2024-06-29 06:07:43,708 - Rank 1: exit barrier 0
2024-06-29 06:07:43,708 - Rank 0: exit barrier 0
2024-06-29 06:07:43,709 - Rank 0: enter barrier 1
2024-06-29 06:07:44,710 - Rank 1: enter barrier 1
2024-06-29 06:07:44,710 - Rank 1: exit barrier 1
2024-06-29 06:07:45,710 - Rank 2: enter barrier 1
2024-06-29 06:07:45,711 - Rank 0: exit barrier 1
2024-06-29 06:07:45,711 - Rank 0: enter barrier 2
2024-06-29 06:07:45,711 - Rank 2: exit barrier 1
2024-06-29 06:07:45,711 - Rank 1: enter barrier 2
2024-06-29 06:07:45,712 - Rank 1: exit barrier 2
2024-06-29 06:07:46,713 - Rank 1: enter barrier 3
2024-06-29 06:07:46,713 - Rank 1: exit barrier 3
2024-06-29 06:07:47,713 - Rank 2: enter barrier 2
2024-06-29 06:07:47,714 - Rank 2: exit barrier 2
2024-06-29 06:07:47,714 - Rank 0: exit barrier 2
2024-06-29 06:07:47,714 - Rank 0: enter barrier 3
2024-06-29 06:07:49,716 - Rank 2: enter barrier 3
2024-06-29 06:07:49,717 - Rank 2: exit barrier 3
2024-06-29 06:07:49,717 - Rank 0: exit barrier 3
...
wujingyue commented 4 days ago

Update: https://github.com/NVIDIA/Fuser/pull/2504 will fix the hang. However, the weird timing of entering/exiting the barrier mentioned in OP and https://github.com/NVIDIA/Fuser/issues/2499#issuecomment-2197771152 is still worrisome. Some ranks failed to wait for the barrier.

cowanmeg commented 3 days ago

This is a bit worrisome... I vaguely recall seeing this error with ProcessGroupUCC, but under slightly different conditions (world size > test size), which is why so many of the tests are disabled.

tfogal commented 3 days ago

tagging @eqy for help understanding https://github.com/NVIDIA/Fuser/issues/2499#issuecomment-2197771152

eqy commented 3 days ago

My guess is that this has something to do with what @Aidyn-A said about barrier() having to guess what the current device is. It seems the current device is still "0" on all ranks with this run setup. If I add a hack that forces the current device to be the current rank, things seem to line up:

import logging
import os
import time
import torch
import torch.distributed as dist

def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    assert world_size >= 3, "This test requires at least 3 GPUs."
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    for i in range(4):
        time.sleep(rank)
        logging.info(f"Rank {rank}: enter barrier {i}")
        dist.barrier()
        logging.info(f"Rank {rank}: exit barrier {i}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Two example runs:

root@4a34daa55aff:/nvdl/repro/distbarrier# ./run.sh run.py
[W701 20:51:43.751653082 socket.cpp:697] [c10d] The client socket has failed to connect to [localhost]:12345 (errno: 99 - Cannot assign requested address).
[W701 20:51:43.819483082 socket.cpp:697] [c10d] The client socket has failed to connect to [localhost]:12345 (errno: 99 - Cannot assign requested address).
2024-07-01 20:51:44,705 - Rank 0: enter barrier 0
2024-07-01 20:51:45,856 - Rank 1: enter barrier 0
2024-07-01 20:51:46,858 - Rank 2: enter barrier 0
2024-07-01 20:51:47,091 - Rank 1: exit barrier 0
2024-07-01 20:51:47,091 - Rank 2: exit barrier 0
2024-07-01 20:51:47,091 - Rank 0: exit barrier 0
2024-07-01 20:51:47,092 - Rank 0: enter barrier 1
2024-07-01 20:51:48,092 - Rank 1: enter barrier 1
2024-07-01 20:51:49,093 - Rank 2: enter barrier 1
2024-07-01 20:51:49,093 - Rank 1: exit barrier 1
2024-07-01 20:51:49,093 - Rank 0: exit barrier 1
2024-07-01 20:51:49,093 - Rank 2: exit barrier 1
2024-07-01 20:51:49,093 - Rank 0: enter barrier 2
2024-07-01 20:51:50,094 - Rank 1: enter barrier 2
2024-07-01 20:51:51,095 - Rank 2: enter barrier 2
2024-07-01 20:51:51,096 - Rank 1: exit barrier 2
2024-07-01 20:51:51,096 - Rank 0: exit barrier 2
2024-07-01 20:51:51,096 - Rank 2: exit barrier 2
2024-07-01 20:51:51,096 - Rank 0: enter barrier 3
2024-07-01 20:51:52,097 - Rank 1: enter barrier 3
2024-07-01 20:51:53,098 - Rank 2: enter barrier 3
2024-07-01 20:51:53,098 - Rank 1: exit barrier 3
2024-07-01 20:51:53,098 - Rank 0: exit barrier 3
2024-07-01 20:51:53,098 - Rank 2: exit barrier 3
root@4a34daa55aff:/nvdl/repro/distbarrier# ./run.sh run.py
[W701 20:52:50.480120952 socket.cpp:697] [c10d] The client socket has failed to connect to [localhost]:12345 (errno: 99 - Cannot assign requested address).
2024-07-01 20:52:52,054 - Rank 0: enter barrier 0
2024-07-01 20:52:53,247 - Rank 1: enter barrier 0
2024-07-01 20:52:53,389 - Rank 2: enter barrier 0
2024-07-01 20:52:53,618 - Rank 1: exit barrier 0
2024-07-01 20:52:53,618 - Rank 2: exit barrier 0
2024-07-01 20:52:53,619 - Rank 0: exit barrier 0
2024-07-01 20:52:53,619 - Rank 0: enter barrier 1
2024-07-01 20:52:54,620 - Rank 1: enter barrier 1
2024-07-01 20:52:55,621 - Rank 2: enter barrier 1
2024-07-01 20:52:55,621 - Rank 1: exit barrier 1
2024-07-01 20:52:55,621 - Rank 0: exit barrier 1
2024-07-01 20:52:55,621 - Rank 2: exit barrier 1
2024-07-01 20:52:55,621 - Rank 0: enter barrier 2
2024-07-01 20:52:56,622 - Rank 1: enter barrier 2
2024-07-01 20:52:57,623 - Rank 2: enter barrier 2
2024-07-01 20:52:57,623 - Rank 0: exit barrier 2
2024-07-01 20:52:57,624 - Rank 2: exit barrier 2
2024-07-01 20:52:57,624 - Rank 0: enter barrier 3
2024-07-01 20:52:57,624 - Rank 1: exit barrier 2
2024-07-01 20:52:58,625 - Rank 1: enter barrier 3
2024-07-01 20:52:59,626 - Rank 2: enter barrier 3
2024-07-01 20:52:59,626 - Rank 1: exit barrier 3
2024-07-01 20:52:59,626 - Rank 0: exit barrier 3
2024-07-01 20:52:59,626 - Rank 2: exit barrier 3
eqy commented 3 days ago

Better update: note that barrier is not device-synchronizing, so logging is free to run ahead and print "exit" before the barrier actually completes: https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier

There is another gotcha: naively adding torch.cuda.synchronize() won't help, because without knowing which device to sync on, it syncs on the default device, which would also be wrong.

So, to actually get the intended behavior, we need to sync on the corresponding device before each print:

import logging
import os
import time
import torch
import torch.distributed as dist

def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    assert world_size >= 3, "This test requires at least 3 GPUs."
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    for i in range(4):
        time.sleep(rank)
        logging.info(f"Rank {rank}: enter barrier {i}")
        dist.barrier()
        torch.cuda.synchronize(rank)
        logging.info(f"Rank {rank}: exit barrier {i}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main() 
wujingyue commented 3 days ago

Better update: note that barrier is not device-synchronizing, so logging is free to run ahead and print "exit" before the barrier actually completes

I must be missing something. My understanding is that barrier is now stream-synchronizing rather than device-synchronizing. Shouldn't it still block the CPU, and therefore the second logging.info?

eqy commented 3 days ago

@wujingyue That's a good point. Either the comment is written in a very misleading way or there's a bug: the allreduce used in the barrier runs on a side ncclStream, so the streamSynchronize on the current stream wouldn't actually make the host wait for the allreduce...
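To make the side-stream argument concrete, here is a toy model in plain Python in which single-worker thread pools stand in for CUDA streams. ToyStream, toy_barrier and barrier_done_after_sync are hypothetical names, not PyTorch or NCCL APIs, and the sketch only illustrates the ordering problem described above: synchronizing the (empty) current stream returns immediately, while the "allreduce" queued on the side stream is still running.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class ToyStream:
    """Stand-in for a CUDA stream: submitted work runs in FIFO order on one thread."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._last = None

    def enqueue(self, fn):
        self._last = self._pool.submit(fn)

    def synchronize(self):
        # Block the caller until everything enqueued on *this* stream has run.
        if self._last is not None:
            self._last.result()


def toy_barrier(side, done, delay_s=0.2):
    """Enqueue a pretend allreduce on the side stream and return immediately."""

    def allreduce():
        time.sleep(delay_s)  # pretend the collective takes a while
        done.set()

    side.enqueue(allreduce)


def barrier_done_after_sync(sync_side_stream):
    current, side = ToyStream(), ToyStream()
    done = threading.Event()
    toy_barrier(side, done)
    current.synchronize()  # returns at once: nothing was queued on this stream
    if sync_side_stream:
        side.synchronize()  # actually waits for the pretend allreduce
    return done.is_set()
```

With sync_side_stream=False the host would "print exit" while the pretend allreduce is still in flight. Only waiting on the stream that ran the collective (or making the current stream depend on it) restores the expected ordering.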

I'll open a patch upstream to "fix" and see what kind of comments we get.

eqy commented 3 days ago

Opened #129908

wujingyue commented 3 days ago

Thanks, Eddie!

A separate concern: I suspect this log may not print the device that is actually used. guessDeviceForRank uses the bound device if one is available, which is not necessarily rank % numGpus. I haven't checked whether this makes a difference in practice, but I think a more reliable approach is to print the result of guessDeviceForRank instead.
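The concern can be modeled in a few lines. This is a hypothetical sketch of the selection logic as described in this comment (prefer a bound device, otherwise fall back to rank % numGpus), not the actual PyTorch implementation of guessDeviceForRank. The names guess_device_for_rank and warning_matches_guess are illustrative only.

```python
from typing import Optional


def guess_device_for_rank(rank: int, num_gpus: int, bound_device: Optional[int] = None) -> int:
    """Model of the described guess: a bound device wins, else rank % num_gpus."""
    if num_gpus <= 0:
        raise ValueError("need at least one GPU")
    if bound_device is not None:
        return bound_device
    return rank % num_gpus


def warning_matches_guess(rank: int, num_gpus: int, bound_device: Optional[int] = None) -> bool:
    # A log line that always reports rank % num_gpus agrees with the guess
    # only when no other device is bound.
    return guess_device_for_rank(rank, num_gpus, bound_device) == rank % num_gpus
```

Printing the guessed device directly, as suggested above, avoids this class of mismatch.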

eqy commented 3 days ago

@wujingyue I checked that function while debugging, and it seems to guess "correctly" for now, in that it matches what you were printing.