pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Cannot broadcast object list for XLA with PJRT backend #5492

Open humzaiqbal opened 1 year ago

humzaiqbal commented 1 year ago

🐛 Bug

Calling broadcast_object_list with the PJRT backend for XLA raises an error.

To Reproduce

Run this code snippet:

import torch
import torch.distributed as dist
import torch_xla.experimental.pjrt_backend  # required for init_method='pjrt://'
import torch_xla.experimental.pjrt as pjrt
import torch_xla.distributed.xla_multiprocessing as xmp
import os

def broadcast_function(dummy_arg):
    os.environ['PJRT_DEVICE'] = 'TPU'
    dist.init_process_group('xla', init_method='pjrt://')
    broadcast_objects = ["2023"]
    src = 0
    dist.broadcast_object_list(broadcast_objects, src=src)  # fails with the device check error shown below
    print("Broadcast successful")

if __name__ == "__main__":
    xmp.spawn(broadcast_function, args=(), nprocs=8)

The code gives the following error:

concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/home/ubuntu/simple_broadcast_test.py", line 14, in broadcast_function
    dist.broadcast_object_list(broadcast_objects, src=src)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2255, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1566, in broadcast
    work = default_pg.broadcast([tensor], opts)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/distributed/xla_backend.py", line 94, in broadcast
    xm.collective_broadcast([root_tensor],
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 798, in collective_broadcast
    xscale = send_cpu_data_to_device(scale, tensor.device)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 1092, in send_cpu_data_to_device
    return ToXlaTensorArena(convert_fn, select_fn).transform(data)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 431, in transform
    self._convert()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 403, in _convert
    self._converted_tensors = self._convert_fn(self._tensors)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 1087, in convert_fn
    return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)
RuntimeError: /pytorch/xla/torch_xla/csrc/aten_xla_bridge.cpp:276 : Check failed: device.type() == at::kXLA (cpu vs. xla)
*** Begin stack trace ***
    tsl::CurrentStackTrace()
    torch_xla::bridge::AtenDeviceToXlaDevice(c10::Device const&)

    PyCFunction_Call
    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall

    PyObject_Call

    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall

    PyObject_Call

    clone
*** End stack trace ***
cpu
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "simple_broadcast_test.py", line 20, in <module>
    xmp.spawn(broadcast_function, args=(), nprocs=8)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
RuntimeError: /pytorch/xla/torch_xla/csrc/aten_xla_bridge.cpp:276 : Check failed: device.type() == at::kXLA (cpu vs. xla)
*** Begin stack trace ***
    tsl::CurrentStackTrace()
    torch_xla::bridge::AtenDeviceToXlaDevice(c10::Device const&)

    PyCFunction_Call
    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall

    PyObject_Call

    _PyObject_MakeTpCall
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    PyObject_Call
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall
    _PyEval_EvalFrameDefault
    _PyFunction_Vectorcall

    PyObject_Call

    clone
*** End stack trace ***
cpu

Expected behavior

The code should execute without error.

Environment

JackCaoG commented 1 year ago

Can you try to broadcast an XLA tensor?

humzaiqbal commented 1 year ago

Ah interesting so if I do something like

t = torch.randn(2, 2, device=xm.xla_device())
src = 0
dist.broadcast(t, src=src)

that works but if I do

t = torch.randn(2, 2, device=xm.xla_device())
broadcast_objects = [t]
src = 0
dist.broadcast_object_list(broadcast_objects, src=src)

it fails with the same error. So it seems like there is an issue with the broadcast_object_list method itself: per the traceback, it first broadcasts a CPU object_sizes_tensor internally, and the XLA backend rejects that with the device check failure. My understanding of the method per its description is that any picklable object can be broadcast, so I don't think I'm feeding it bad input.

humzaiqbal commented 1 year ago

Additionally, trying this

t = torch.randn(2, 2, device=xm.xla_device())
broadcast_objects = [t]
src = 0
dist.broadcast_object_list(broadcast_objects, src=src, device=xm.xla_device())

leads to the following output:

Broadcast successful
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/home/ubuntu/simple_broadcast_test.py", line 16, in broadcast_function
    dist.broadcast_object_list(broadcast_objects, src=src, device=xm.xla_device())
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2277, in broadcast_object_list
    object_list[i] = _tensor_to_object(obj_view, obj_size)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1970, in _tensor_to_object
    return _unpickler(io.BytesIO(buf)).load()
EOFError: Ran out of input
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "simple_broadcast_test.py", line 22, in <module>
    xmp.spawn(broadcast_function, args=(), nprocs=8)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
EOFError: Ran out of input

carmocca commented 1 year ago

At Lightning, we worked around these limitations by serializing objects as byte buffers: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/strategies/xla.py#L193-L207
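
For reference, a minimal sketch of that byte-buffer approach (not Lightning's exact code, and untested here): the source rank pickles the object with torch.save, ships the bytes across ranks as a tensor on the XLA device using the plain dist.broadcast that was shown to work above, and every rank unpickles the result. The helper name broadcast_picklable and the float32 payload dtype are choices made for this sketch, not an existing torch_xla API.

import io

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

def broadcast_picklable(obj, src=0):
    # Hypothetical helper: broadcast any picklable object by serializing it
    # and moving the resulting bytes to the XLA device.
    device = xm.xla_device()
    rank = dist.get_rank()

    if rank == src:
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length = torch.tensor([len(data)], dtype=torch.long, device=device)
    else:
        length = torch.zeros(1, dtype=torch.long, device=device)

    # Broadcast the payload size first so receivers can allocate a buffer.
    dist.broadcast(length, src=src)

    if rank == src:
        payload = torch.frombuffer(data, dtype=torch.uint8).to(device, torch.float32)
    else:
        payload = torch.zeros(int(length.item()), dtype=torch.float32, device=device)

    # Broadcast the serialized bytes (carried as float32 values here), then
    # deserialize the object on every rank.
    dist.broadcast(payload, src=src)
    raw = payload.cpu().to(torch.uint8).numpy().tobytes()
    return torch.load(io.BytesIO(raw))

In the repro above, each rank would call result = broadcast_picklable(broadcast_objects, src=src) in place of dist.broadcast_object_list; the value passed on non-source ranks is ignored.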

BitPhinix commented 4 months ago

Experiencing this as well

JackCaoG commented 4 months ago

Let me take a look later today.

JackCaoG commented 4 months ago

OK, I think we can do what Lightning does in https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/strategies/xla.py#L214-L246, which is also what we do in mesh_reduce in https://github.com/pytorch/xla/blob/2cfa13cad6fba184262d1b79cd79883a8199881e/torch_xla/core/xla_model.py#L1411-L1433

I think this is one of those usability issues. @zpcore do you think you will have bandwidth to pick this one up?
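
For completeness, a minimal sketch of what that could look like as a user-side workaround (broadcast_object is a hypothetical name, not an existing API): mesh_reduce already pickles each replica's value and gathers all of them in ordinal order, so broadcasting an object reduces to selecting the source replica's entry in the reduce function.

import torch_xla.core.xla_model as xm

def broadcast_object(obj, src=0):
    # Every replica contributes its (possibly dummy) value; the reduce
    # function simply returns the entry coming from the source replica.
    return xm.mesh_reduce("broadcast_object", obj, lambda replicas: replicas[src])

A fix inside the XLA backend's broadcast_object_list path would presumably follow the same serialize-then-collective pattern.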

zpcore commented 4 months ago

Thanks, I will pick this up in Q3.