intel / torch-ccl

oneCCL Bindings for Pytorch*
BSD 3-Clause "New" or "Revised" License

DDP(model) gets stuck in a cluster when running demo.py manually #46

Open leonardozcm opened 1 year ago

leonardozcm commented 1 year ago

Torch/torch-ccl/IPEX version: 1.13.0
Cluster nodes: 2
World size: 2

All nodes have password-less connections set up, and mpirun works as the README describes:

mpirun -f ./hosts -n 2 -ppn 1 -genv OMP_NUM_THREADS=24 python demo.py 
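With mpirun, the launcher rather than the user supplies each process's rank. Below is a minimal sketch of how a script can accept either source. It assumes that an MPI launcher such as Intel MPI's Hydra exports `PMI_RANK` and `PMI_SIZE`. The helper name `torch_dist_env` is hypothetical and is not part of torch-ccl.

```python
import os
from typing import Dict, Mapping


def torch_dist_env(env: Mapping[str, str]) -> Dict[str, str]:
    """Return RANK and WORLD_SIZE for torch.distributed's env:// init.

    Assumption: an MPI launcher such as Intel MPI's Hydra exports PMI_RANK and
    PMI_SIZE for every process; a manual launch sets RANK and WORLD_SIZE itself.
    """
    if "PMI_RANK" in env and "PMI_SIZE" in env:
        # Launched by mpirun: trust the launcher's rank assignment.
        return {"RANK": env["PMI_RANK"], "WORLD_SIZE": env["PMI_SIZE"]}
    # Launched by hand: use the user's values, defaulting to a single process.
    return {"RANK": env.get("RANK", "0"), "WORLD_SIZE": env.get("WORLD_SIZE", "1")}


if __name__ == "__main__":
    # Export the values so a later dist.init_process_group() call can read them.
    os.environ.update(torch_dist_env(os.environ))
    print(os.environ["RANK"], os.environ["WORLD_SIZE"])
```

In a manual launch no `PMI_*` variables exist, so the script depends entirely on the `RANK` and `WORLD_SIZE` given on the command line.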

I then tried to run it manually by starting training on each node:

# in node 0
RANK=0 WORLD_SIZE=2 python demo.py
# in node 1
RANK=1 WORLD_SIZE=2 python demo.py
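With the default `env://` initialization, a manual launch also needs every process to share the same `MASTER_ADDR` and `MASTER_PORT`. These must point at an address of the rank-0 node that the other nodes can reach. The completed store-based barrier in the log suggests this rendezvous step did succeed here. Below is a sketch that generates the per-node environment. The name `manual_launch_envs` and the host name `node0` are hypothetical, and port 29500 is only the conventional PyTorch default.

```python
from typing import Dict, List


def manual_launch_envs(master_addr: str, world_size: int, master_port: int = 29500) -> List[Dict[str, str]]:
    """Return one env:// environment per rank for a launch without mpirun.

    master_addr must be an address of the rank-0 node reachable from all nodes.
    With one process per node, rank i runs on node i.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    return [
        {
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
        }
        for rank in range(world_size)
    ]


if __name__ == "__main__":
    # Print the command to run on each node ("node0" is a placeholder host name).
    for env in manual_launch_envs("node0", world_size=2):
        prefix = " ".join(f"{k}={v}" for k, v in env.items())
        print(f"# rank {env['RANK']}: {prefix} python demo.py")
```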

Launched this way, training gets stuck at DDP(model):

/home/cpx/anaconda3/envs/bigdl_test/lib/python3.7/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libc10_cuda.so: cannot open shared object file: No such file or directory
  warn(f"Failed to load image Python extension: {e}")
1 2
2023-05-05 16:13:22,973 - torch.distributed.distributed_c10d - INFO - Added key: store_based_barrier_key:1 to store for rank: 1
2023-05-05 16:13:22,984 - torch.distributed.distributed_c10d - INFO - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
2023:05:05-16:13:30:(3185397) |CCL_WARN| did not find MPI-launcher specific variables, switch to ATL/OFI, to force enable ATL/MPI set CCL_ATL_TRANSPORT=mpi

The hang does not occur if I use dist.init_process_group(backend='gloo') instead.
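Because gloo works and ccl does not, it helps to switch between the two without editing the script. Below is a minimal sketch that reads the backend name from the environment and fails loudly instead of silently falling back. It assumes that the `"ccl"` backend only becomes available once `oneccl_bindings_for_pytorch` can be imported. The `pick_backend` helper and the `DIST_BACKEND` variable are hypothetical.

```python
import importlib.util

# Only the two CPU backends compared in this thread are accepted.
_CPU_BACKENDS = {"ccl", "gloo"}


def pick_backend(requested: str) -> str:
    """Validate a backend name before passing it to init_process_group.

    Assumption: "ccl" is registered by importing oneccl_bindings_for_pytorch,
    so requesting it without that module installed is reported as an error.
    """
    if requested not in _CPU_BACKENDS:
        raise ValueError(f"unsupported backend {requested!r}; expected one of {sorted(_CPU_BACKENDS)}")
    if requested == "ccl" and importlib.util.find_spec("oneccl_bindings_for_pytorch") is None:
        raise RuntimeError("backend 'ccl' requested but oneccl_bindings_for_pytorch is not installed")
    return requested
```

In demo.py the result would be passed as `dist.init_process_group(backend=pick_backend(os.environ.get("DIST_BACKEND", "ccl")))`. When `ccl` is chosen, `import oneccl_bindings_for_pytorch` would still need to run before that call.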

leonardozcm commented 1 year ago

I found where it gets stuck: https://github.com/pytorch/pytorch/blob/main/torch/nn/parallel/distributed.py#L809

abhilash1910 commented 1 year ago

@leonardozcm What backend are you using when calling torch.distributed.init_process_group()? The recommended backend is "ccl". Judging by the error, you might have set backend="nccl", which would explain why it is looking for libc10_cuda.so. Could you share a snippet or reproducer, if possible?