humzaiqbal opened this issue 1 year ago
Can you try to broadcast an XLA tensor?
Ah, interesting. If I do something like
```python
t = torch.randn(2, 2, device=xm.xla_device())
src = 0
dist.broadcast(t, src=src)
```
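For completeness, since the quoted snippets omit the setup: here is a minimal self-contained version of the working tensor broadcast. This is a sketch, not the exact script from the issue; in particular, the backend registration import and the init_method string changed across torch_xla releases, so both are assumptions to adapt to your version.

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # assumed; registers the 'xla' backend under PJRT (module path varies by release)

def _mp_fn(index):
    # init_method is 'pjrt://' on the torch_xla release this issue targets; newer releases use 'xla://'
    dist.init_process_group('xla', init_method='pjrt://')
    t = torch.randn(2, 2, device=xm.xla_device())
    dist.broadcast(t, src=0)  # plain tensor broadcast: this path works
    xm.mark_step()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```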
that works, but if I do
```python
t = torch.randn(2, 2, device=xm.xla_device())
broadcast_objects = [t]
src = 0
dist.broadcast_object_list(broadcast_objects, src=src)
```
it fails with the same error, so there seems to be an issue with the broadcast_object_list method specifically. My understanding from the method's description is that any picklable object can be broadcast, so I don't think I'm feeding it bad input.
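For context on where this can go wrong: broadcast_object_list does not broadcast the objects directly. Per the distributed_c10d.py frames in the traceback below, it pickles each object into a byte buffer, broadcasts the buffer lengths and then the concatenated bytes as plain tensors, and unpickles on the receiving ranks. A simplified sketch of that flow (not the exact upstream code):

```python
import io
import pickle
import torch
import torch.distributed as dist

def broadcast_object_list_sketch(object_list, src, group=None):
    # 1) The source rank pickles every object into one flat byte payload.
    if dist.get_rank() == src:
        buffers = [pickle.dumps(obj) for obj in object_list]
        sizes = torch.tensor([len(b) for b in buffers], dtype=torch.long)
        payload = torch.frombuffer(bytearray(b''.join(buffers)), dtype=torch.uint8)
    else:
        sizes = torch.empty(len(object_list), dtype=torch.long)
    # 2) Broadcast the per-object sizes, then the payload itself.
    dist.broadcast(sizes, src=src, group=group)
    if dist.get_rank() != src:
        payload = torch.empty(int(sizes.sum()), dtype=torch.uint8)
    dist.broadcast(payload, src=src, group=group)
    # 3) Receivers unpickle each slice. If the payload never arrived intact,
    #    the unpickler sees an empty or truncated buffer -> "EOFError: Ran out of input".
    if dist.get_rank() != src:
        offset = 0
        for i, size in enumerate(sizes.tolist()):
            buf = payload[offset:offset + size].numpy().tobytes()
            object_list[i] = pickle.load(io.BytesIO(buf))
            offset += size
```

The EOFError below is consistent with the unpickling step reading back a buffer that the broadcast did not actually fill on the XLA backend.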
Additionally, trying this
```python
t = torch.randn(2, 2, device=xm.xla_device())
broadcast_objects = [t]
src = 0
dist.broadcast_object_list(broadcast_objects, src=src, device=xm.xla_device())
```
leads to the following output
```
Broadcast successful
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
replica_results = list(
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
return fn()
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
self.fn(global_ordinal(), *self.args, **self.kwargs)
File "/home/ubuntu/simple_broadcast_test.py", line 16, in broadcast_function
dist.broadcast_object_list(broadcast_objects, src=src, device=xm.xla_device())
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
return func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2277, in broadcast_object_list
object_list[i] = _tensor_to_object(obj_view, obj_size)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1970, in _tensor_to_object
return _unpickler(io.BytesIO(buf)).load()
EOFError: Ran out of input
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "simple_broadcast_test.py", line 22, in <module>
xmp.spawn(broadcast_function, args=(), nprocs=8)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/home/ubuntu/.local/lib/python3.8/site-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
EOFError: Ran out of input
```
At Lightning, we worked around these limitations by serializing objects as byte buffers: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/strategies/xla.py#L193-L207
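In the spirit of that linked workaround (a sketch, not Lightning's verbatim code): pickle the object yourself, push the bytes through a plain tensor broadcast, which does work on XLA as shown above, and unpickle on the receiving ranks.

```python
import io
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

def broadcast_object(obj, src=0):
    rank = dist.get_rank()
    device = xm.xla_device()
    if rank == src:
        # Serialize the object to a byte buffer on the source rank.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length = torch.tensor([len(data)], dtype=torch.long, device=device)
    else:
        length = torch.zeros(1, dtype=torch.long, device=device)
    dist.broadcast(length, src=src)  # tell every rank how many bytes to expect
    if rank == src:
        payload = torch.frombuffer(data, dtype=torch.uint8).to(device)
    else:
        payload = torch.zeros(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src)  # ship the pickled bytes as a plain tensor
    return torch.load(io.BytesIO(payload.cpu().numpy().tobytes()))
```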
Experiencing this as well
Let me take a look later today.
OK, I think we can do what Lightning does in https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/strategies/xla.py#L214-L246, which is also what we do in mesh_reduce: https://github.com/pytorch/xla/blob/2cfa13cad6fba184262d1b79cd79883a8199881e/torch_xla/core/xla_model.py#L1411-L1433
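Since mesh_reduce already moves arbitrary picklable payloads between replicas, a broadcast falls out of it by reducing to the source replica's value. A usage sketch (assuming the gathered payloads are ordered by replica ordinal, which matches the linked implementation):

```python
import torch_xla.core.xla_model as xm

src = 0
# Only the source replica needs a real value; the others contribute a placeholder.
value = {'step': 42} if xm.get_ordinal() == src else None
value = xm.mesh_reduce('broadcast_tag', value, lambda values: values[src])
# After this, every replica holds {'step': 42}.
```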
I think this is one of those usability issues. @zpcore do you think you will have bandwidth to pick this one up?
Thanks, I will pick this up in Q3.
🐛 Bug
Calling broadcast_object_list with the PJRT backend for XLA raises an error.
To Reproduce
Run this code snippet
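The snippet itself was not preserved in this copy of the issue; the following is a reconstruction from the traceback above (simple_broadcast_test.py, a broadcast_function, and xmp.spawn with nprocs=8), so details such as the backend registration import and the init_method string are assumptions.

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # assumed; registers the 'xla' backend under PJRT

def broadcast_function(index):
    dist.init_process_group('xla', init_method='pjrt://')  # init_method varies by torch_xla release
    t = torch.randn(2, 2, device=xm.xla_device())
    broadcast_objects = [t]
    src = 0
    dist.broadcast_object_list(broadcast_objects, src=src, device=xm.xla_device())

if __name__ == '__main__':
    xmp.spawn(broadcast_function, args=(), nprocs=8)
```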
The code gives the following error
Expected behavior
The code would execute without issue
Environment