DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

[Bug Report] Distributed Training RuntimeError #156

Open mrzzmrzz opened 1 year ago

mrzzmrzz commented 1 year ago

Hey, I found a bug when using Distributed Data Parallel (DDP) training across multiple nodes.

I use 4 GPUs in total (two nodes with two GPUs each), but the run fails. Running two GPUs on a single node works fine.

Here is the log:

RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:210] address family mismatch
Traceback (most recent call last):
  File "script/downstream.py", line 47, in <module>
    working_dir = util.create_working_directory(cfg)
  File "/home/mrzz/util.py", line 38, in create_working_directory
    comm.init_process_group("nccl", init_method="env://")
  File "/home/miniconda3/envs/torchdrug/lib/python3.8/site-packages/torchdrug/utils/comm.py", line 67, in init_process_group
    cpu_group = dist.new_group(backend="gloo")
  File "/home/miniconda3/envs/torchdrug/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2503, in new_group
    pg = _new_process_group_helper(group_world_size,
  File "/home/miniconda3/envs/torchdrug/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 588, in _new_process_group_helper
    pg = ProcessGroupGloo(
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:210] address family mismatch
Killing subprocess 1345179
Killing subprocess 1345180

Then I commented out some code in comm.py, and luckily the training ran successfully.

def init_process_group(backend, init_method=None, **kwargs):
    """
    Initialize CPU and/or GPU process groups.

    Parameters:
        backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs.
        init_method (str, optional): URL specifying how to initialize the process group
    """
    global cpu_group
    global gpu_group

    dist.init_process_group(backend, init_method, **kwargs)
    gpu_group = dist.group.WORLD
    # if backend == "nccl":
    #     cpu_group = dist.new_group(backend="gloo")
    # else:
    cpu_group = gpu_group

It seems that, when running on multiple nodes, init_process_group fails while creating the CPU process group via dist.new_group(backend="gloo").

I am not sure whether this analysis is right; maybe you can look into this bug more comprehensively. Thank you for your work.
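Alternatively, instead of editing comm.py, I guess one could pin both backends to the same network interface and keep the gloo group. A minimal sketch, assuming the error comes from gloo picking a different interface or address family than NCCL; the interface name eth0 is only a placeholder:

import os
import torch.distributed as dist

# GLOO_SOCKET_IFNAME / NCCL_SOCKET_IFNAME are standard torch.distributed variables.
# They must be set before any process group is created; "eth0" is a placeholder
# for the real interface name on each node.
os.environ.setdefault("GLOO_SOCKET_IFNAME", "eth0")
os.environ.setdefault("NCCL_SOCKET_IFNAME", "eth0")

dist.init_process_group("nccl", init_method="env://")  # GPU group (NCCL)
cpu_group = dist.new_group(backend="gloo")             # CPU group (gloo), now bound to the same interface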

KiddoZhu commented 1 year ago

Hi! Thanks for the bug report. I could run multi-node multi-GPU training in the past. Could you provide your PyTorch version?

From my understanding, NCCL is only implemented for GPUs. If we want to communicate any CPU tensor, the NCCL group will complain, which is why we additionally initialize a gloo group.
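For example, a minimal sketch of what the two groups are for, using the module-level cpu_group / gpu_group from the comm.py snippet above (not a complete script; it must run under a distributed launcher such as torch.distributed.launch):

import torch
import torch.distributed as dist
from torchdrug.utils import comm

comm.init_process_group("nccl", init_method="env://")

gpu_tensor = torch.ones(4, device="cuda")
dist.all_reduce(gpu_tensor, group=comm.gpu_group)   # CUDA tensors go through the NCCL group

cpu_tensor = torch.ones(4)
dist.all_reduce(cpu_tensor, group=comm.cpu_group)   # CPU tensors need the gloo group;
                                                    # the NCCL group cannot reduce them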

mrzzmrzz commented 1 year ago

Hey, here are my NCCL-related environment variables:

NCCL_SOCKET_IFNAME=en,eth,em,bond 
NCCL_P2P_DISABLE=0
LD_LIBRARY_PATH=:/usr/local/cuda/lib64 
NCCL_DEBUG=INFO

My PyTorch and CUDA versions are torch 1.8.0+cu111.

After disabling the CPU group, I can run the multi-node multi-GPU training successfully, and the final result seems normal.

KiddoZhu commented 1 year ago

I will take a close look.

Note that most multi-node multi-GPU training doesn't require CPU tensor communication, so those cases should be fine. The only place TorchDrug uses it is knowledge graph reasoning, where storing the intermediate results on GPU may overflow the GPU memory.
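Roughly, the pattern there looks like this (a sketch rather than the exact TorchDrug code; it assumes equally sized per-rank results and the gloo cpu_group created in comm.py):

import torch
import torch.distributed as dist

def gather_on_cpu(result_gpu, cpu_group):
    # Move the potentially large intermediate result off the GPU first, then
    # gather it across ranks over the gloo CPU group so GPU memory is not exhausted.
    result_cpu = result_gpu.cpu()
    gathered = [torch.zeros_like(result_cpu) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, result_cpu, group=cpu_group)
    return torch.cat(gathered)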

KiddoZhu commented 1 year ago

@Mrz-zz I can successfully run the distributed version of NBFNet, which involves both GPU (NCCL) and CPU (gloo) communications. The commands I used for two nodes are:

python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=xxx ...
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=xxx ...

where master_addr is the alias or IP address of the rank-0 node. I got the alias from import platform; print(platform.node()). The two nodes are on the same LAN. My environment is based on PyTorch 1.8.1 and CUDA 11.2, so it's roughly the same as yours.

I am not sure what causes the failure in your case. There is a similar question in the PyTorch forum that suggests using the netcat command to test your network. Maybe you can give it a try?
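If netcat is not handy, here is a rough Python equivalent (MASTER_ADDR / MASTER_PORT are the values passed to the launcher; run it on the non-master node while the rank-0 process is up):

import os
import socket

addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
port = int(os.environ.get("MASTER_PORT", "29500"))

# Show which address family the master address resolves to on this node;
# an IPv4 vs IPv6 disagreement between nodes is one possible cause of
# the "address family mismatch" error from gloo.
for family, _, _, _, sockaddr in socket.getaddrinfo(addr, port, type=socket.SOCK_STREAM):
    print(family, sockaddr)

# Try an actual TCP connection to the rank-0 node.
with socket.create_connection((addr, port), timeout=5) as sock:
    print("connected to", sock.getpeername())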

mrzzmrzz commented 1 year ago

I guess something is wrong with GLOO_SOCKET_IFNAME. Could you share your OS environment settings for GLOO_SOCKET_IFNAME, NCCL_SOCKET_IFNAME, and NCCL_P2P_DISABLE?

KiddoZhu commented 1 year ago

Could you share your OS environment settings for GLOO_SOCKET_IFNAME, NCCL_SOCKET_IFNAME, and NCCL_P2P_DISABLE?

All of them are empty in my environment.
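When they are unset, gloo and NCCL pick a network interface automatically, and the automatic choice can differ between nodes or backends. If you want to rule that out, here is a short sketch to list each node's interfaces and address families (it assumes psutil is installed, which is not a TorchDrug dependency):

import psutil  # assumption: psutil is available; not a TorchDrug dependency

# Print every interface with its address families (AF_INET = IPv4, AF_INET6 = IPv6).
# Pick an interface that has an IPv4 address on every node, then export it on all
# nodes before launching, e.g. GLOO_SOCKET_IFNAME=<name> and NCCL_SOCKET_IFNAME=<name>.
for name, addrs in psutil.net_if_addrs().items():
    families = sorted(str(addr.family) for addr in addrs)
    print(name, families)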