facebookincubator / gloo

Collective communications library with various primitives for multi-machine training.

Network discovery interface? #367

Closed: surak closed this issue 1 year ago

surak commented 1 year ago

On PyTorch 1.12.0, I have Python code that tries to create a gloo group on our supercomputer.

    for j in range(2):
        # Split the 8 ranks into two groups of 4: even ranks (j=0) and odd ranks (j=1).
        ranks = range(j, 8, 2)
        group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            print('gloo rank', rank, 'in gloo group', ranks)

This works among nodes on the same island of the supercomputer, since gloo does at least part of the communication over the service Ethernet network, but it fails when it has to go over the InfiniBand cards.

Our host naming scheme works as follows:

host-x points to the Ethernet adapter; host-xi points to the InfiniBand adapter.

When asking with something like hostname, the name returned is always that of the Ethernet adapter.
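
Just to illustrate the scheme, here is a quick check one could run on a compute node from Python (the names in the comments are the placeholders above, not real hosts):

    import socket

    # Illustration only: the plain hostname resolves to the Ethernet address,
    # while the same name with a trailing "i" resolves to the InfiniBand address.
    eth_name = socket.gethostname()   # e.g. "host-x"
    ib_name = eth_name + "i"          # e.g. "host-xi"

    print('ethernet  :', eth_name, '->', socket.gethostbyname(eth_name))
    print('infiniband:', ib_name, '->', socket.gethostbyname(ib_name))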

For NCCL we solve the rendezvous part by forcing MASTER_ADDR, adding this to the Slurm script:

MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
# Allow communication over InfiniBand cells.
MASTER_ADDR="${MASTER_ADDR}i"

We also patched PyTorch to understand that (the patch is attached at the end).

For gloo, this doesn't work.

Using export GLOO_SOCKET_IFNAME=ib0 doesn't solve the issue.

So I would like some pointers on how to solve this issue, or at least where to look in the source.

PyTorch-1.12.0_nic-hostname.patch

surak commented 1 year ago

I get an error like this:


Traceback (most recent call last):
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 45, in <module>
    main()
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 33, in main
    test2()
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 21, in test2
    group_gloo = torch.distributed.new_group(ranks, backend="gloo")
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2974, in new_group
    pg = _new_process_group_helper(
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 703, in _new_process_group_helper
    pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout)
RuntimeError: [/dev/shm/strube1/juwelsbooster/PyTorch/1.12.0/foss-2022a-CUDA-11.7/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:799] connect [10.11.241.193]:43394: Connection timed out
Traceback (most recent call last):
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 45, in <module>
    main()
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 33, in main
    test2()
  File "/p/scratch/opengptx-elm/john2/torch_test/main.py", line 21, in test2
    group_gloo = torch.distributed.new_group(ranks, backend="gloo")
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2974, in new_group
    pg = _new_process_group_helper(
  File "/p/software/juwelsbooster/stages/2023/software/PyTorch/1.12.0-foss-2022a-CUDA-11.7/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 703, in _new_process_group_helper
    pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout)
RuntimeError: [/dev/shm/strube1/juwelsbooster/PyTorch/1.12.0/foss-2022a-CUDA-11.7/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:799] connect [10.11.241.193]:54670: Connection timed out
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 597 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 599 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 596) of binary: /p/software/juwelsbooster/stages/2023/software/Python/3.10.4-GCCcore-11.3.0/bin/python
INFO:torch.distributed.elastic.agent.server.api:Local worker group finished (FAILED). Waiting 300 seconds for other agents to finish

The important line is this:

RuntimeError: [/dev/shm/strube1/juwelsbooster/PyTorch/1.12.0/foss-2022a-CUDA-11.7/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:799] connect [10.11.241.193]:43394: Connection timed out

where the IP address clearly belongs to the Ethernet interface:

nslookup 10.11.241.193
193.241.11.10.in-addr.arpa      name = jwb0247.juwels.

(the hostname is missing the trailing i)
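
To double-check, one can compare what the two names resolve to (the InfiniBand name below is an assumption following our naming scheme, not confirmed output):

    import socket

    # Assumed names, following the host-x / host-xi convention: the plain name
    # should resolve to 10.11.241.193 (Ethernet), while the name with the
    # trailing "i" should resolve to the InfiniBand address gloo ought to use.
    for name in ("jwb0247.juwels", "jwb0247i.juwels"):
        try:
            print(name, "->", socket.gethostbyname(name))
        except socket.gaierror as err:
            print(name, "-> lookup failed:", err)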