pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Intermittent multiprocessing error on google cloud TPU #3947

Open alif-munim opened 2 years ago

alif-munim commented 2 years ago

❓ Questions and Help

Hello, I've been trying to run a basic MNIST training example on 8 TPU cores on Google Cloud, but I periodically run into the errors below when calling xmp.spawn to begin training.
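For context, the launch follows the standard xmp.spawn pattern; a simplified sketch is below, where _mp_fn and FLAGS stand in for my actual training function and flags:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    # Each spawned process acquires its own TPU core here.
    device = xm.xla_device()
    # ... build the MNIST model and data loaders, then run the training loop on `device` ...

FLAGS = {}  # placeholder for the real training flags
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')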

Error 1 (Snippet):

Exception in device=TPU:0: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:     "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance vs. OK)
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    device = xm.xla_device()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 231, in xla_device
    devices = get_xla_supported_devices(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 137, in get_xla_supported_devices
    xla_devices = _DEVICES.value
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 32, in value
    self._value = self._gen_fn()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 19, in <lambda>
    _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())

Error 2 (Snippet):

Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:2: Cannot replicate if number of devices (1) is different from 8

Strangely, these errors disappear for a while and the code runs fine, then they suddenly pop back up again. The same code was running on 8 cores just over an hour ago. Also worth noting: training works fine on a single core, with no errors. I'll include the full error logs below.

Any help would be greatly appreciated!

alif-munim commented 2 years ago

Error 1 (Full):

Exception in device=TPU:0: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:     "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance vs. OK)
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    device = xm.xla_device()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 231, in xla_device
    devices = get_xla_supported_devices(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 137, in get_xla_supported_devices
    xla_devices = _DEVICES.value
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 32, in value
    self._value = self._gen_fn()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 19, in <lambda>
    _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:      "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance vs. OK)
2022-08-29 16:58:00.378200: E tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:580] PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:    "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7519 (TID 7519) on cpu 16 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:00.904083    7519 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:00.926883    7519 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7616 (TID 7616) on cpu 0 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.064157    7616 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.087023    7616 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7620 (TID 7620) on cpu 17 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.225164    7620 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.248046    7620 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7624 (TID 7624) on cpu 49 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.383841    7624 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.406620    7624 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7628 (TID 7628) on cpu 14 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.543524    7628 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.566334    7628 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7633 (TID 7633) on cpu 23 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.724528    7633 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.747290    7633 process_state.cc:774] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500 
*** SIGTERM received by PID 7636 (TID 7636) on cpu 6 from PID 7408; stack trace: ***
PC: @     0x7fb77c0d246e  (unknown)  epoll_wait
    @     0x7fb669b462f3        992  (unknown)
    @     0x7fb77bff6090  1251481952  (unknown)
    @     0x7fb679a41d3d        224  cq_next()
    @     0x7fb67996ddb1         64  grpc_impl::CompletionQueue::AsyncNextInternal()
    @     0x7fb6799683a7        176  grpc_impl::Channel::WaitForStateChangeImpl()
    @     0x7fb66d880f29        512  xla::service::MeshClient::MeshClient()
    @     0x7fb66d8812b4         80  xla::service::MeshClient::Get()
    @     0x7fb66d879f5e       1472  xla::ComputationClient::Create()
    @     0x7fb66d87b213         32  std::call_once<>()::{lambda()#2}::_FUN()
    @     0x7fb77bfa14df  (unknown)  __pthread_once_slow
https://symbolize.stripped_domain/r/?trace=7fb77c0d246e,7fb669b462f2,7fb77bff608f,7fb679a41d3c,7fb67996ddb0,7fb6799683a6,7fb66d880f28,7fb66d8812b3,7fb66d879f5d,7fb66d87b212,7fb77bfa14de&map=635d802eb3479bd129fe8817c389e8619db396e5:7fb66a7ce000-7fb67cece500,5920735bb186a93a82e10840f91bc184:7fb655353000-7fb669ec9fb0 
E0829 16:58:01.903827    7636 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0829 16:58:01.926659    7636 process_state.cc:774] RAW: Raising signal 15 with default behavior
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
Input In [6], in <cell line: 12>()
      8   if rank == 0:
      9     # Retrieve tensors that are on TPU core 0 and plot.
     10     plot_results(data.cpu(), pred.cpu(), target.cpu())
---> 12 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8,
     13           start_method='fork')

File /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py:389, in spawn(fn, args, nprocs, join, daemon, start_method)
    387   _start_fn(0, pf_cfg, fn, args)
    388 else:
--> 389   return torch.multiprocessing.start_processes(
    390       _mp_start_fn,
    391       args=(pf_cfg, fn, args),
    392       nprocs=pf_cfg.num_devices,
    393       join=join,
    394       daemon=daemon,
    395       start_method=start_method)

File /usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py:198, in start_processes(fn, args, nprocs, join, daemon, start_method)
    195     return context
    197 # Loop on join until it returns True or raises an exception.
--> 198 while not context.join():
    199     pass

File /usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py:149, in ProcessContext.join(self, timeout)
    140         raise ProcessExitedException(
    141             "process %d terminated with signal %s" %
    142             (error_index, name),
   (...)
    146             signal_name=name
    147         )
    148     else:
--> 149         raise ProcessExitedException(
    150             "process %d terminated with exit code %d" %
    151             (error_index, exitcode),
    152             error_index=error_index,
    153             error_pid=failed_process.pid,
    154             exit_code=exitcode
    155         )
    157 original_trace = self.error_queues[error_index].get()
    158 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index

ProcessExitedException: process 0 terminated with exit code 17
alif-munim commented 2 years ago

Error 2 (Full):

Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:2: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
Traceback (most recent call last):
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:3: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
Traceback (most recent call last):
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
Exception in device=TPU:4: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
RuntimeError: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:5: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
Traceback (most recent call last):
RuntimeError: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:6: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
RuntimeError: Cannot replicate if number of devices (1) is different from 8
Exception in device=TPU:7: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
RuntimeError: Cannot replicate if number of devices (1) is different from 8

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 286, in xla_replication_devices
    raise RuntimeError(
RuntimeError: Cannot replicate if number of devices (1) is different from 8
https://symbolize.stripped_domain/r/?trace=7f67f3e2a99f,7f67f3d5b08f,7f67f275f3a2,7f67f273c562,7f67f27b3d8b,7f67f27b3fef,7f67f21e2a37,2554c0f&map=f814e6f344aa5a0a982b6c297b480edf07a2da37:7f67f2722000-7f67f280704e,d26e5c91fac6e647d247576495f39239c07306f3:7f67f21d8000-7f67f21e6574 
*** SIGTERM received by PID 9328 (TID 9328) on cpu 35 from PID 4342; stack trace: ***
PC: @     0x7f67f3e2a99f  (unknown)  poll
    @     0x7f66e0b6b2f3        992  (unknown)
    @     0x7f67f3d5b090  506136816  (unknown)
    @     0x7f67f275f3a3         80  (unknown)
    @     0x7f67f273c563        336  (unknown)
    @     0x7f67f27b3d8c         48  zmq_ctx_term
    @     0x7f67f27b3ff0         32  zmq_ctx_destroy
    @     0x7f67f21e2a38  (unknown)  (unknown)
    @          0x2554c10  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f67f3e2a99f,7f66e0b6b2f2,7f67f3d5b08f,7f67f275f3a2,7f67f273c562,7f67f27b3d8b,7f67f27b3fef,7f67f21e2a37,2554c0f&map=f814e6f344aa5a0a982b6c297b480edf07a2da37:7f67f2722000-7f67f280704e,d26e5c91fac6e647d247576495f39239c07306f3:7f67f21d8000-7f67f21e6574,5920735bb186a93a82e10840f91bc184:7f66cc378000-7f66e0eeefb0 
E0829 17:23:57.051959    9328 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
https://symbolize.stripped_domain/r/?trace=7f66e0bc8643,7f67f3d5b08f,7f66e0a72306,7f66e0b81044,7f66e0b86bcc,7f66e0b85772,7f66e0b85229,7f66e0ee814d,7f66e0b6bd8a,7f67f3d5b08f,7f67f275f3a2,7f67f273c562,7f67f27b3d8b,7f67f27b3fef,7f67f21e2a37,2554c0f&map=f814e6f344aa5a0a982b6c297b480edf07a2da37:7f67f2722000-7f67f280704e,d26e5c91fac6e647d247576495f39239c07306f3:7f67f21d8000-7f67f21e6574,5920735bb186a93a82e10840f91bc184:7f66cc378000-7f66e0eeefb0 
E0829 17:23:57.052688    9328 process_state.cc:1067] RAW: Signal 11 raised at PC: 0x7f66e0bc8643 while already in FailureSignalHandler!
E0829 17:23:57.052697    9328 process_state.cc:1102] RAW: Raising 11 signal with default behavior
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
File <timed eval>:1, in <module>

File /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py:389, in spawn(fn, args, nprocs, join, daemon, start_method)
    387   _start_fn(0, pf_cfg, fn, args)
    388 else:
--> 389   return torch.multiprocessing.start_processes(
    390       _mp_start_fn,
    391       args=(pf_cfg, fn, args),
    392       nprocs=pf_cfg.num_devices,
    393       join=join,
    394       daemon=daemon,
    395       start_method=start_method)

File /usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py:198, in start_processes(fn, args, nprocs, join, daemon, start_method)
    195     return context
    197 # Loop on join until it returns True or raises an exception.
--> 198 while not context.join():
    199     pass

File /usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py:149, in ProcessContext.join(self, timeout)
    140         raise ProcessExitedException(
    141             "process %d terminated with signal %s" %
    142             (error_index, name),
   (...)
    146             signal_name=name
    147         )
    148     else:
--> 149         raise ProcessExitedException(
    150             "process %d terminated with exit code %d" %
    151             (error_index, exitcode),
    152             error_index=error_index,
    153             error_pid=failed_process.pid,
    154             exit_code=exitcode
    155         )
    157 original_trace = self.error_queues[error_index].get()
    158 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index

ProcessExitedException: process 0 terminated with exit code 17
alif-munim commented 2 years ago

Found some additional information in the PyTorch Lightning docs section on TPUs, which mentions that you should not call xm.xla_device() outside of the spawned process. I've removed that line (a rough sketch of the change is below), and training now seems to begin, but I run into the errors that follow.
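Concretely, the change was roughly this (a sketch only; same imports and FLAGS as in my first snippet, with _mp_fn standing in for my actual training function):

# Before: the device was created in the notebook process, outside the spawn
# device = xm.xla_device()   # <-- this line has been removed
# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

# After: the device is only acquired inside each spawned process
def _mp_fn(index, flags):
    device = xm.xla_device()
    # ... build model/loaders and run the training loop on `device` ...

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')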

Error 3 (Snippet):

Exception in device=TPU:1: tensorflow/compiler/xla/xla_client/mesh_service.cc:329 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 

Error 4 (Snippet):

2022-08-29 20:10:46.436423: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1661803846.436277160","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
alif-munim commented 2 years ago

Error 3 (Full):

Exception in device=TPU:1: tensorflow/compiler/xla/xla_client/mesh_service.cc:329 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 
*** Begin stack trace ***
    tensorflow::CurrentStackTrace()
    xla::service::MeshClient::MeshClient(std::string const&)
    xla::service::MeshClient::Get()
    xla::ComputationClient::Create()

    xla::ComputationClient::Get()

    PyCFunction_Call
    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault

    PyObject_GetAttr
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName

    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault

    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    PyEval_EvalCode

    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall

    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    PyEval_EvalCode

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall

    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyObject_MakeTpCall

    PyVectorcall_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName

    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    PyEval_EvalCode

    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
*** End stack trace ***
Failed to connect to client mesh master: t1v-n-c21ceeb4-w-0.us-central1-b.c.conversational-ai-project.internal:54271
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    device = xm.xla_device()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 231, in xla_device
    devices = get_xla_supported_devices(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 137, in get_xla_supported_devices
    xla_devices = _DEVICES.value

The above error is repeated for each TPU device.

alif-munim commented 2 years ago

Error 4 (Full):

2022-08-29 20:10:46.436423: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1661803846.436277160","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC

The above error is repeated multiple times throughout training. The progress and training-completion print statements are also printed multiple times, so it seems like the model is being trained separately on each core. The training cell then gets stuck after the completion messages are printed.
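To check whether the cores are actually joined into a single replication group, I'm thinking of printing the ordinal and world size from inside the spawned function, roughly like this (a sketch using the standard xla_model helpers, with the same xm alias as above; with proper replication I'd expect world_size=8 and distinct ordinals 0-7):

def _mp_fn(index, flags):
    device = xm.xla_device()
    # With proper replication, every process should report world_size 8 and a unique ordinal.
    print(f'ordinal={xm.get_ordinal()} world_size={xm.xrt_world_size()} device={device}')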

JackCaoG commented 2 years ago

Thanks for reporting. Before we dive too deep into the errors, can I get some information first?

  1. Are you using a TPU VM or a TPU Node (through Colab, Kaggle, etc.)? You can check https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm for more detail.
  2. Which version of PyTorch and PyTorch/XLA are you using?
  3. How did you create the TPU (what image)?
alif-munim commented 2 years ago

Sure! I think @MarkCoatsworth could provide some more information about our setup.

fffffgggg54 commented 2 years ago

I am experiencing the PERMISSION_DENIED: open(/dev/accel0) error while using PyTorch 1.12 with a TPU VM. In my experience, the error is caused by interrupting the program after the model and batches have been loaded to the device. Restarting the VM resolves the issue consistently.

sudo reboot -> reconnect -> run program: no error
sudo reboot -> reconnect -> run program -> Ctrl+C after calling .to(device) -> run program: error
sudo reboot -> reconnect -> run program -> Ctrl+C after calling .to(device) -> sudo reboot -> reconnect -> run program: no error

JackCaoG commented 2 years ago

I am experiencing the PERMISSION_DENIED: open(/dev/accel0) error while using PyTorch 1.12 with a TPU VM. In my experience, the error is caused by interrupting the program after the model and batches have been loaded to the device. Restarting the VM resolves the issue consistently.

sudo reboot -> reconnect -> run program: no error
sudo reboot -> reconnect -> run program -> Ctrl+C after calling .to(device) -> run program: error
sudo reboot -> reconnect -> run program -> Ctrl+C after calling .to(device) -> sudo reboot -> reconnect -> run program: no error

Maybe just pkill -f python? It should clean up the dangling processes that hold the lock on the device.

markcoatsworth commented 2 years ago

Hi @JackCaoG, sorry for the slow reply! We're using TPU VMs. I'm not sure if it's important, but we created them from the gcloud CLI instead of the web console, using the following command:

gcloud compute tpus tpu-vm create <tpu-name> --zone=us-central1-b --accelerator-type=v2-8 --version=tpu-vm-pt-1.12 --scopes=https://www.googleapis.com/auth/cloud-platform

We're using PyTorch and PyTorch/XLA 1.12, with the tpu-vm-pt-1.12 image as indicated above. Let me know if there's anything else you need.

JackCaoG commented 2 years ago

2022-08-29 20:10:46.436423: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1661803846.436277160","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC

suggests that one of the processes (cores) crashed and the others can't reach it. You encountered a couple of errors above; were you able to get it to run eventually? Also, are you using PyTorch Lightning?

As a sanity check, can you run our ResNet test after

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd xla/
python test/test_train_mp_imagenet.py --fake_data

and see if it runs? I am trying to figure out whether this is a model-code issue or an infra/config issue. The default config you should use is: export XRT_TPU_CONFIG='localservice;0;localhost:51011'