pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

sporadic error: mesh_service.cc:259 : Check failed: impl_->channel->WaitForConnected #2216

Closed hrbigelow closed 4 years ago

hrbigelow commented 4 years ago

🐛 Bug

UPDATE 2: Okay, so it turns out this has nothing to do with the body of _mp_fn. The error occurs before _mp_fn is even entered. It is still sporadic, but once the error occurs, it tends to occur repeatedly, and the only remedy seems to be to disconnect the TPU Colab runtime and reconnect. When it occurs, the call to xmp.spawn hangs for about five minutes and then crashes with the error messages below.

UPDATE: It turns out that if I completely remove the calls to xm.mesh_reduce, the exact same thing happens. I then set num_workers=4 (instead of num_workers=0) in torch.utils.data.DataLoader (while still leaving the xm.mesh_reduce calls commented out), and that now runs. Adding the xm.mesh_reduce calls back in now produces unpickling errors. Please stand by, I will fiddle with this further.

Hi Davide, Jin Young et al,

After I incorporated two calls to xm.mesh_reduce (using scalar tensor arguments), I have seen an error sporadically. Usually, the first time I start the TPU runtime, the code runs fine, but after a few starts/stops, I get the error below.

Just for background, the reason I'd like to do this is that I am reporting training loss and other statistics every 50 steps. But, instead of reporting the individual statistics compiled from each replica, I would like to report the statistics that actually correspond to the xm.optimizer_step, so I am first combining those stats using the mesh_reduce calls, and then reporting them only from the master ordinal. Is there a better way?
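To make the intent concrete, here is a minimal sketch of the reporting pattern described above. It is not the code from my repository: the helper name `report_reduced` and its parameters are hypothetical, and the reduction and master-check functions are passed in as arguments so that the sketch runs without a TPU. In a torch_xla program I would expect to pass `xm.mesh_reduce` and `xm.is_master_ordinal`, and plain Python floats (e.g. `loss.item()`) as the values.

```python
from statistics import mean


def report_reduced(step, local_stats, mesh_reduce, is_master, every=50):
    """Combine per-replica statistics and print them from one process only.

    Hypothetical helper. `mesh_reduce(tag, data, reduce_fn)` and `is_master()`
    are injected; under torch_xla they would presumably be xm.mesh_reduce and
    xm.is_master_ordinal.
    """
    if step % every != 0:
        return None

    # Every replica has to make these calls, not just the master, because the
    # reduction waits for all replicas. Sorting the keys keeps the tag order
    # identical on every replica.
    reduced = {
        name: mesh_reduce(name, value, mean)
        for name, value in sorted(local_stats.items())
    }

    # Only the master prints, so each report appears once per interval.
    if is_master():
        summary = ", ".join(f"{k}={v:.4f}" for k, v in reduced.items())
        print(f"step {step}: {summary}")
    return reduced
```

The reduction sits outside the `is_master()` check on purpose: if only the master called it, the other replicas would never join the reduction.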

Here is an excerpt:

Failed to connect to client mesh master: 13f6e0717ed2:40581
Traceback (most recent call last):
  File "/content/ae-wavenet/train.py", line 38, in <module>
    fire.Fire(run)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/content/ae-wavenet/train.py", line 34, in run
    xmp.spawn(_mp_fn, args=(hps, dat_file), nprocs=8)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 296, in spawn
    start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception: 

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 228, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 220, in _setup_replication
    device = xm.xla_device()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 146, in xla_device
    devkind=[devkind] if devkind is not None else None)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 50, in get_xla_supported_devices
    xla_devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:259 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 
*** Begin stack trace ***
    tensorflow::CurrentStackTrace[abi:cxx11]()
    xla::service::MeshClient::MeshClient(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
    xla::service::MeshClient::Get()
    xla::ComputationClient::Create()
    xla::ComputationClient::Get()

And here is a full trace: client_mesh_error.txt

The relevant code is here

And here is the Colab file. I am using spawn with nprocs=8. I am also using DataLoader, but with num_workers=0, because its multiprocessing doesn't seem to work well with spawn, in case that is relevant.

Thanks very much for any help you could provide!

To Reproduce

Steps to reproduce the behavior:

  1. In case this is a known issue, I haven't yet taken the time to pare it down to a minimal example, but you can go to the Colab and execute the first three cells to reproduce it.

Expected behavior

The same code and input work properly sometimes, and other times I get this error.


stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.