Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

[TPU-Colab] RuntimeError: Cannot replicate if number of devices (1) is different from 8 #1703

Closed simonepri closed 4 years ago

simonepri commented 4 years ago

🐛 Bug

When I run trainer.test(model) on a pre-trained model using a Colab TPU instance, the following exception is thrown.

NB: trainer.fit(model) works.

Stack trace

Traceback (most recent call last):
  File "run_pl_ged.py", line 217, in <module>
    trainer.test(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 958, in test
    self.fit(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 777, in fit
    xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 182, in spawn
    start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 109, in _setup_replication
    xm.set_replication(str(device), [str(device)])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 194, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 181, in xla_replication_devices
    .format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

Code sample

import pytorch_lightning as pl

model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
model = model.load_from_checkpoint(checkpoint)
model.prepare_data() # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1562
trainer.test(model)

Environment

Colab TPU instance with XLA 1.5

* CUDA:
    - GPU:
    - available:         False
    - version:           None
* Packages:
    - numpy:             1.18.3
    - pyTorch_debug:     False
    - pyTorch_version:   1.5.0a0+ab660ae
    - pytorch-lightning: 0.7.5
    - tensorboard:       2.2.1
    - tqdm:              4.38.0
* System:
    - OS:                Linux
    - architecture:
        - 64bit
        - 
    - processor:         x86_64
    - python:            3.6.9
    - version:           #1 SMP Wed Feb 19 05:26:34 PST 2020

Possibly related: https://github.com/PyTorchLightning/pytorch-lightning/pull/1019

github-actions[bot] commented 4 years ago

Hi! Thanks for your contribution, great first issue!

simonepri commented 4 years ago

Any update? Can I help speed this up somehow?

ArthDh commented 4 years ago

I was facing the same issue on a Colab TPU instance.

pytorch-lightning==0.7.6
torch==1.6.0a0+246d7bb
torch-xla==1.6+62b4c42
torchvision==0.7.0a0+c2e8a00

Using trainer = pl.Trainer(resume_from_checkpoint=str(best_ckpt), num_tpu_cores=1) followed by: trainer.test(model)

results in:

training on 1 TPU cores
INIT TPU local core: 0, global rank: 0
Exception in device=TPU:0: tensorflow/compiler/xla/xla_client/mesh_service.cc:259 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 
*** Begin stack trace ***
    tensorflow::CurrentStackTrace[abi:cxx11]()
    xla::service::MeshClient::MeshClient(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
    xla::service::MeshClient::Get()

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault

    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    PyObject_Call

    Py_Main
    main
    __libc_start_main
    _start
*** End stack trace ***
Failed to connect to client mesh master: 06e59f028bed:60141
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/distrib_parts.py", line 535, in tpu_train
    self.run_pretrain_routine(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 951, in run_pretrain_routine
    torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 679, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:259 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 
*** Begin stack trace ***
    tensorflow::CurrentStackTrace[abi:cxx11]()
    xla::service::MeshClient::MeshClient(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
    xla::service::MeshClient::Get()

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    PyObject_Call

    Py_Main
    main
    __libc_start_main
    _start
*** End stack trace ***
Failed to connect to client mesh master: 06e59f028bed:60141
An exception has occurred, use %tb to see the full traceback.
inidhinarayan commented 4 years ago

I was facing the same issue on a Colab TPU instance.

pytorch-lightning==0.7.6
torch==1.6.0a0+246d7bb
torch-xla==1.6+62b4c42
torchvision==0.7.0a0+c2e8a00

Using trainer = pl.Trainer(resume_from_checkpoint=str(best_ckpt), num_tpu_cores=1) followed by: trainer.test(model)

results in:

training on 1 TPU cores
...
Failed to connect to client mesh master: 06e59f028bed:60141
An exception has occurred, use %tb to see the full traceback.

Did you get any solution? I am facing the same issue.

ArthDh commented 4 years ago

Hi @inidhinarayan, I couldn't find a way around it. You might want to try with the latest repo!

edenlightning commented 4 years ago

@nidhinarayan can you let us know if this is still happening on master?

Borda commented 4 years ago

It should be fixed on master, feel free to reopen if needed 🐰

lezwon commented 4 years ago

I spent some time debugging this issue. I suspect the problem occurs because Lightning loads the XLA weights back onto the device. The weights are saved by the master device xla:1 during training. When reloading, they are automatically moved back to xla:1, and the current process then acquires only that single TPU core and treats it as the TPU device. Any attempt to call xmp.spawn after this results in the given error.

To fix this issue we need to save the weights with xm.save() instead of torch.save, which transfers them to the CPU before saving. This issue is related to https://github.com/PyTorchLightning/pytorch-lightning/pull/2726
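
A minimal sketch of that approach using the public torch_xla API (this is not Lightning's actual checkpoint code from the PR above; the function and path names are illustrative):

import torch_xla.core.xla_model as xm

def save_tpu_checkpoint(model, path="checkpoint.ckpt"):
    """Save a state dict so reloading it does not pin the process to one XLA device."""
    state_dict = model.state_dict()  # tensors may still live on an XLA device such as xla:1

    # torch.save(state_dict, path)   # problematic: serializes the XLA device placement,
    #                                # so loading re-acquires a single core and a later
    #                                # xmp.spawn(..., nprocs=8) hits the replication error

    xm.save(state_dict, path)        # moves the tensors to CPU and writes only on the
                                     # master ordinal, so the checkpoint is device-neutral

A checkpoint written this way loads onto the CPU first, so a subsequent trainer.test(model) can still spawn all eight cores.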

sapthrishi commented 3 years ago

I'm still facing this issue today on Kaggle:

training on 8 TPU cores
training on 8 TPU cores
Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

(The same exception and traceback are raised for every device, TPU:0 through TPU:7; the per-process output is interleaved in the original log.)

LisburnLad commented 3 years ago

I'm experiencing the same issue with the [Lightning TPU example notebook](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb), run on Colab.

Both the single TPU core examples work, but when trying to run on 8 cores I get the error:

"RuntimeError: Cannot replicate if number of devices (1) is different from 8"

acharjee07 commented 3 years ago

Having the same issue with a Kaggle kernel.

acharjee07 commented 3 years ago

Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)

lezwon commented 3 years ago

@LisburnLad @Sanjay03079 did you run trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5]) before you ran trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)?
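
For reference, a hedged sketch of the sequence being asked about here (not a confirmed reproduction; MyLightningModule is the module from the original report and the trainer arguments mirror the ones above):

import pytorch_lightning as pl

model = MyLightningModule()

# First Trainer pins the process to one explicit TPU core (index 5).
trainer_single = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])
trainer_single.fit(model)

# A second Trainer in the same process then asks for all 8 cores; if the process is
# still holding the single device, this is where "Cannot replicate if number of
# devices (1) is different from 8" would be raised.
trainer_multi = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)
trainer_multi.fit(model)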

talhaanwarch commented 3 years ago

@lezwon I only ran trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8), and I also tried restarting the kernel, but it did not work.

daniellepintz commented 3 years ago

I am experiencing the same issue in #9712. https://app.circleci.com/pipelines/github/PyTorchLightning/pytorch-lightning/44836/workflows/20f3cc67-3596-4d27-8ecb-c909a3cf6577/jobs/132588/parallel-runs/0/steps/0-119

@Borda any insight here?

ynhuhu commented 3 years ago

The problem feels unsolvable.

satpalsr commented 2 years ago

I confirm the problem still exists.

kaushikb11 commented 2 years ago

@satpalsr Could you share your reproducible script for the bug? I could take a look.

JessicaLopezEspejel commented 2 years ago

Hello, any news? I am facing the same problem.