Open bram-w opened 1 year ago
I have the same problem with a 95M-parameter transformer model - the checkpoint normally takes only ~5 seconds to save. During training it succeeded in saving the first checkpoint (1.2 GB) but failed on the second with the traceback below. The environment is Google Colab TPU, and I was saving to my Google Drive folder mounted with
from google.colab import drive
drive.mount('/content/drive')
By the way, I am calling xm.save via a closure registered with xm.add_step_closure; I found that it doesn't work without it.
Context is here
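For reference, here is roughly how that is wired up (a minimal sketch only; the model, optimizer, and path names are placeholders, not my actual training code):
import torch_xla.core.xla_model as xm

def save_checkpoint(model, optimizer, path):
    # Build the checkpoint dict; xm.save moves tensors to CPU and writes
    # from the master ordinal, rendezvousing all ordinals under XRT.
    ckpt = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    xm.save(ckpt, path)

# Registered inside the training loop: the closure runs at the next
# xm.mark_step(), i.e. when the parallel loader steps, which matches
# the traceback below.
xm.add_step_closure(save_checkpoint, args=(model, optimizer, '/content/drive/MyDrive/ckpt.pt'))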
Exception in device=TPU:7: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 334, in _mp_start_fn
_start_fn(index, pf_cfg, fn, args)
File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 328, in _start_fn
fn(gindex, *args)
File "/usr/local/lib/python3.10/dist-packages/aiayn/train.py", line 144, in _mp_fn
train_loop_xla(run)
File "/usr/local/lib/python3.10/dist-packages/aiayn/train.py", line 189, in train_loop_xla
enc_input, dec_input, load_step, epoch = next(run.loader)
File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/parallel_loader.py", line 30, in __next__
return self.next()
File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/parallel_loader.py", line 42, in next
xm.mark_step()
File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 956, in mark_step
devctx = _run_step_closures()
File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 938, in _run_step_closures
closure()
File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 920, in <lambda>
step_closures.append(lambda a=args: closure(*a))
File "/usr/local/lib/python3.10/dist-packages/aiayn/pause.py", line 73, in save
xm.save(ckpt, path)
File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 1065, in save
rendezvous('torch_xla.core.xla_model.save')
File "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py", line 1112, in rendezvous
return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
@hrbigelow It seems like you are using XRT (since you are on Colab). I took a look at https://github.com/pytorch/xla/blob/r2.0/third_party/xla_client/mesh_service.cc#L323 and was not able to find where to tune the tolerance. If you are able to use a TPU VM, that will let you use PJRT, where we implemented the rendezvous with all_gather cc ops instead of a gRPC-server-based solution; I believe that version doesn't have a timeout.
In a recent PR we also removed the rendezvous in xm.save, https://github.com/pytorch/xla/commit/4f7f3dcd4b82aaa9acfcb63057cebd80e989ebf7. To unblock yourself you can also try setting sync to False.
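If your torch_xla build already contains that commit, the workaround would look roughly like this (a sketch only; whether xm.save exposes a sync keyword depends on the installed torch_xla version, so check its signature first):
import torch_xla.core.xla_model as xm

def save_checkpoint(ckpt, path):
    # Skip the cross-ordinal rendezvous that the older xm.save performed.
    # On releases without the 'sync' keyword this call raises a TypeError.
    xm.save(ckpt, path, sync=False)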
❓ Questions and Help
Hi, this might be a basic question, but how do I increase the timeout of xm.rendezvous()? I'm training a large model, and on the system we're training on, saving can take >5 minutes, which results in timeout errors such as:
2023-03-29 13:52:59 172.16.96.171 [1] RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Connection reset by peer (14)
Sorry if I missed this in the documentation. I might have misinterpreted this error, but it seems like a basic rendezvous timeout? Thanks!
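For anyone hitting the same error: the failure comes from the rendezvous that xm.save performs internally (tag 'torch_xla.core.xla_model.save' in the tracebacks above). The pattern below is an illustrative stand-in for that behavior, not the library source, showing why a slow master-side write can exhaust the XRT mesh-service tolerance for the other ordinals waiting at the barrier:
import torch
import torch_xla.core.xla_model as xm

def master_only_save(state_dict, path):
    if xm.is_master_ordinal():
        # Can take several minutes on slow storage (e.g. a mounted Drive).
        torch.save(state_dict, path)
    # Every ordinal must reach this tag; under XRT the gRPC mesh service
    # drops the waiting connections after its tolerance, producing the
    # "Failed to meet rendezvous ... Connection reset by peer" error.
    xm.rendezvous('checkpoint_done')  # 'checkpoint_done' is an arbitrary tag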