pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Increasing rendezvous timeout patience? #4831

Open bram-w opened 1 year ago

bram-w commented 1 year ago

❓ Questions and Help

Hi, this might be a basic question, but how do I increase the timeout of xm.rendezvous()? I'm training a large model, and on the system we're training on, saving can take more than 5 minutes, which results in timeout errors such as:

2023-03-29 13:52:59 172.16.96.171 [1] RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Connection reset by peer (14)

Sorry if I missed this in the documentation. I might have misinterpreted this error, but it seems like a basic rendezvous timeout? Thanks!
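A toy sketch of how a slow checkpoint write on one replica can surface as a failed rendezvous on the others, assuming the failure really is a timeout. It uses Python's `threading.Barrier`, not torch_xla's mesh service, and the `replica`/`run` helpers and the `timeout_s`/`save_s` parameters are hypothetical.

```python
import threading
import time
from typing import Dict

# Toy model of a barrier-style rendezvous with a fixed patience.
# This is NOT torch_xla's mesh service; it only illustrates the failure mode.

def replica(ordinal: int, barrier: threading.Barrier, timeout_s: float,
            save_s: float, results: Dict[int, str]) -> None:
    if ordinal == 0:
        time.sleep(save_s)  # the master is busy writing a checkpoint
    try:
        # Each waiter gives up after timeout_s if not everyone has arrived.
        barrier.wait(timeout=timeout_s)
        results[ordinal] = "ok"
    except threading.BrokenBarrierError:
        # A timeout on any replica breaks the barrier for all of them,
        # including the late master when it finally calls wait().
        results[ordinal] = "failed to meet rendezvous"

def run(num_replicas: int, timeout_s: float, save_s: float) -> Dict[int, str]:
    barrier = threading.Barrier(num_replicas)
    results: Dict[int, str] = {}
    threads = [
        threading.Thread(target=replica,
                         args=(i, barrier, timeout_s, save_s, results))
        for i in range(num_replicas)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

if __name__ == "__main__":
    print(run(4, timeout_s=0.2, save_s=0.5))  # save outlasts the patience
    print(run(4, timeout_s=2.0, save_s=0.1))  # enough patience
```

In this model the only remedies are more patience (a larger `timeout_s`) or a faster save.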

hrbigelow commented 1 year ago

I have the same problem with a 95M-parameter transformer model; the checkpoint normally takes only ~5 seconds to save. During training, it successfully saved the first checkpoint (1.2 GB) but failed on the second one with the traceback below. The environment is a Google Colab TPU, and I was saving to my Google Drive folder, mounted with:

from google.colab import drive
drive.mount('/content/drive')

By the way, I am calling xm.save from a closure registered with xm.add_step_closure; I found that it doesn't work otherwise.

Context is here

Exception in device=TPU:7: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 334, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 328, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.10/dist-packages/aiayn/train.py", line 144, in _mp_fn
    train_loop_xla(run)
  File "/usr/local/lib/python3.10/dist-packages/aiayn/train.py", line 189, in train_loop_xla
    enc_input, dec_input, load_step, epoch = next(run.loader)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/parallel_loader.py", line 30, in __next__
    return self.next()
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/parallel_loader.py", line 42, in next
    xm.mark_step()
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 956, in mark_step
    devctx = _run_step_closures()
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 938, in _run_step_closures
    closure()
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 920, in <lambda>
    step_closures.append(lambda a=args: closure(*a))
  File "/usr/local/lib/python3.10/dist-packages/aiayn/pause.py", line 73, in save
    xm.save(ckpt, path)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 1065, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 1112, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
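The traceback places the failing save inside `next(run.loader)`: the loader's `next()` calls `xm.mark_step()`, which runs the closures queued earlier by `xm.add_step_closure`. A minimal toy model of that deferral follows; the `StepClosures` class is hypothetical and only mirrors the call path visible in the traceback, not torch_xla's implementation.

```python
from typing import Any, Callable, List

class StepClosures:
    """Toy model of step closures (hypothetical class, not torch_xla's API).

    add_step_closure() only queues work; mark_step() runs it later.
    """

    def __init__(self) -> None:
        self._pending: List[Callable[[], Any]] = []

    def add_step_closure(self, closure: Callable[..., Any], args: tuple = ()) -> None:
        # Bind the arguments now, like `lambda a=args: closure(*a)` in the
        # traceback, but defer the call itself.
        self._pending.append(lambda a=args: closure(*a))

    def mark_step(self) -> None:
        # Run and clear every queued closure, in the order it was added.
        pending, self._pending = self._pending, []
        for fn in pending:
            fn()

if __name__ == "__main__":
    log: List[str] = []
    steps = StepClosures()
    steps.add_step_closure(log.append, ("save checkpoint",))
    print(log)         # []
    steps.mark_step()  # in the traceback, triggered by next(loader)
    print(log)         # ['save checkpoint']
```

This is why the rendezvous error is reported from the loader call rather than from the line that requested the save.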

JackCaoG commented 1 year ago

@hrbigelow It seems like you are using XRT (since you are on Colab). I took a look at https://github.com/pytorch/xla/blob/r2.0/third_party/xla_client/mesh_service.cc#L323 and was not able to find where to tune the tolerance. If you are able to use a TPU VM, that will allow you to use PJRT, where we implemented the rendezvous using the all_gather CC ops instead of a gRPC-server-based solution. I believe that implementation doesn't have a timeout.
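A rendezvous built on an all-gather, as described for the PJRT path, can be pictured with a thread-based toy: every participant contributes a payload, receives everyone's payloads, and waits without a deadline. The `AllGatherRendezvous` class and `demo()` helper are hypothetical; they model the semantics, not torch_xla's code.

```python
import threading
import time
from typing import Dict, List

class AllGatherRendezvous:
    """Toy single-use rendezvous built as an all-gather (hypothetical class,
    not the PJRT implementation). Each participant contributes a payload and
    blocks, with no timeout, until every payload has arrived."""

    def __init__(self, world_size: int) -> None:
        self._world_size = world_size
        self._payloads: List[bytes] = [b""] * world_size
        self._arrived = 0
        self._cond = threading.Condition()

    def rendezvous(self, ordinal: int, payload: bytes = b"") -> List[bytes]:
        with self._cond:
            self._payloads[ordinal] = payload
            self._arrived += 1
            if self._arrived == self._world_size:
                self._cond.notify_all()
            else:
                # No deadline: a slow peer delays everyone but never raises.
                self._cond.wait_for(lambda: self._arrived == self._world_size)
            # Every participant gets all payloads, ordered by ordinal.
            return list(self._payloads)

def demo(world_size: int = 4, slow_s: float = 0.3) -> Dict[int, List[bytes]]:
    rdv = AllGatherRendezvous(world_size)
    results: Dict[int, List[bytes]] = {}

    def worker(i: int) -> None:
        if i == 0:
            time.sleep(slow_s)  # e.g. the master is still writing a checkpoint
        results[i] = rdv.rendezvous(i, str(i).encode())

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

if __name__ == "__main__":
    print(demo())
```

The trade-off of having no deadline is that a replica that never arrives hangs the others instead of producing an error.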

In a recent PR we also removed the rendezvous from xm.save: https://github.com/pytorch/xla/commit/4f7f3dcd4b82aaa9acfcb63057cebd80e989ebf7. To unblock yourself, you can also try setting sync to false.
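The traceback in this thread shows xm.save ending in `rendezvous('torch_xla.core.xla_model.save')`. A simplified model of a master-only save with and without that trailing barrier is sketched below; the `save_checkpoint` and `run` helpers, their `sync` flag, and the pickle format are hypothetical and are not torch_xla's code or API.

```python
import os
import pickle
import tempfile
import threading
from typing import Any, Dict, Optional

def save_checkpoint(data: Any, path: str, ordinal: int,
                    barrier: Optional[threading.Barrier] = None) -> bool:
    """Hypothetical master-only save; returns True if this replica wrote `path`."""
    wrote = False
    if ordinal == 0:
        # Only the master replica touches the filesystem.
        with open(path, "wb") as f:
            pickle.dump(data, f)
        wrote = True
    if barrier is not None:
        # "Synchronous" variant: nobody proceeds until every replica,
        # including the (possibly slow) writer, has reached this point.
        barrier.wait()
    return wrote

def run(num_replicas: int, sync: bool) -> Dict[int, bool]:
    barrier = threading.Barrier(num_replicas) if sync else None
    wrote: Dict[int, bool] = {}
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "ckpt.pkl")

        def worker(i: int) -> None:
            wrote[i] = save_checkpoint({"step": 100}, path, i, barrier)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_replicas)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # After all replicas finish, the master's checkpoint is on disk.
        with open(path, "rb") as f:
            assert pickle.load(f) == {"step": 100}
    return wrote
```

In both modes only ordinal 0 writes; with the barrier, the other replicas also wait out the write, which is the wait that failed in this thread's tracebacks. Without it, they move on immediately, so any replica that later reads the file must not assume it is already complete.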