pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

TPU pod training hangs with "Waiting to connect to client mesh master" #2137

Closed: harpone closed this issue 4 years ago

harpone commented 4 years ago

🐛 Bug

I have an instance group of 4 VMs based on the torch-xla-1.5 boot image and a v3-32 TPU pod. Running the test_train_imagenet.py training script works fine, but test_train_mp_imagenet.py hangs.

To Reproduce

Steps to reproduce the behavior:

  1. This works:

python -m torch_xla.distributed.xla_dist --tpu=node-v3-32 --conda-env=torch-xla-1.5 --env=XLA_USE_BF16=1 -- python /usr/share/torch-xla-1.5/pytorch/xla/test/test_train_imagenet.py --fake_data

  2. This doesn't:

python -m torch_xla.distributed.xla_dist --tpu=node-v3-32 --conda-env=torch-xla-1.5 --env=XLA_USE_BF16=1 -- python /usr/share/torch-xla-1.5/pytorch/xla/test/test_train_mp_imagenet.py --fake_data

I get

2020-05-28 11:50:27 10.164.15.200 [1] 2020-05-28 11:50:27.058553: I 7449 tensorflow/compiler/xla/xla_client/mesh_service.cc:208] Waiting to connect to client mesh master (300 seconds) 10.164.15.199:8477 for all the 32 TPU cores...

Environment

Doesn't the MP distributed training work with TPU pods?

dlibenzi commented 4 years ago

Can you give it a try with nightly?
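
(For reference, a nightly launch might look like the sketch below. The `torch-xla-nightly` conda env name and the `/usr/share/torch-xla-nightly` path are assumptions based on how the 1.5 image lays things out, and the TPU node would also need to be recreated with a matching nightly runtime.)

```bash
# Sketch only: assumes the boot image ships a `torch-xla-nightly` conda env
# under /usr/share/torch-xla-nightly and that node-v3-32 was created with a
# matching nightly PyTorch/XLA runtime version.
python -m torch_xla.distributed.xla_dist \
  --tpu=node-v3-32 \
  --conda-env=torch-xla-nightly \
  --env=XLA_USE_BF16=1 \
  -- python /usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py --fake_data
```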

dlibenzi commented 4 years ago

Daniel, can you take a look?

jysohn23 commented 4 years ago

Are you using the pytorch-1.5 version on the TPU side as well?
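
(One way to check which version the TPU node was created with is the sketch below; the zone is a placeholder.)

```bash
# Placeholder zone: describe the TPU node and look for the runtime version in
# the output; it should correspond to the pytorch-1.5 env used on the VMs.
gcloud compute tpus describe node-v3-32 --zone=<your-zone>
```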

Also, there was a bug, fixed after 1.5, where the job had to be launched from the first VM in the GCE instance group. I'm wondering if it's that. But yeah, if you could also try the latest nightly, that'd be great.
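
(In practice the workaround looks roughly like the sketch below; the instance group, zone, and VM names are placeholders.)

```bash
# Placeholder names: list the GCE instance group backing the pod job and
# identify its first VM.
gcloud compute instance-groups list-instances my-instance-group --zone=<your-zone>

# SSH into the first VM from that list...
gcloud compute ssh <first-vm-name> --zone=<your-zone>

# ...and on that VM, launch the job as before:
python -m torch_xla.distributed.xla_dist --tpu=node-v3-32 --conda-env=torch-xla-1.5 \
  --env=XLA_USE_BF16=1 -- python /usr/share/torch-xla-1.5/pytorch/xla/test/test_train_mp_imagenet.py --fake_data
```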

harpone commented 4 years ago

Oh wow, it's working fine from the first VM! I was SSHing into the last VM by default in my bash script :P

OK I think I can live with this until the next update. Thanks all! :D

PS: pretty amazing that it works out of the box like that, without having to set up any of the NCCL comms stuff you need with CUDA :)

dlibenzi commented 4 years ago

The idea for us is that even when using GPU devices with pytorch/xla, there should not be any explicit NCCL setup to do; xla_dist will take care of that.
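
(Concretely, the only distributed bring-up in this thread is the single launcher command already shown above; a hedged restatement with comments on what it handles:)

```bash
# No MASTER_ADDR/MASTER_PORT/NCCL_* exports or per-process rank plumbing here:
# xla_dist runs the command on every VM of the instance group and sets up the
# XRT/mesh configuration for all 32 TPU cores itself.
python -m torch_xla.distributed.xla_dist \
  --tpu=node-v3-32 \
  --conda-env=torch-xla-1.5 \
  --env=XLA_USE_BF16=1 \
  -- python /usr/share/torch-xla-1.5/pytorch/xla/test/test_train_mp_imagenet.py --fake_data
```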