Opened by carlesoctav 5 months ago
Hi @carlesoctav, thanks for opening the PR. torch.distributed is not fully supported on TPU v2/v3. See the related docs here.
Can you check if `torch.distributed.all_gather_object` works on TPU? If so, we can potentially extend this function to TPUs as well.
Do you think this part of the documentation is relevant to your question? I can run the code below on a v4-8:
```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend


def _all_gather(index: int):
    # No need to pass in `rank` or `world_size`
    dist.init_process_group('xla', init_method='xla://')

    t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
    output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(output, t)

    xm.mark_step()
    print(output)


if __name__ == '__main__':
    xmp.spawn(_all_gather)
```
Note: Although the xla:// init_method is not required on TPU v4, it is still recommended. If you use env://, MASTER_ADDR must be set to the IP of the host that has device 0, which is not always worker 0. The xla:// init_method finds this IP automatically.
Note: For TPU v2/v3, you still need to import torch_xla.experimental.pjrt_backend, as TPU v2/v3 support in torch.distributed is still experimental.
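As a side note, a minimal sketch of the v2/v3 variation that note describes (assuming the same spawn-based script as above) would only add the extra import before initializing the process group:

```python
# Sketch for TPU v2/v3, per the note above: import the experimental PjRt backend
# before init_process_group; the rest of the script stays the same.
import torch.distributed as dist
import torch_xla.experimental.pjrt_backend  # noqa: F401

dist.init_process_group('xla', init_method='xla://')
```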
Well, if I naively change `all_gather` to `all_gather_object`, the above code results in an error:
```
*** End stack trace ***
Input tensor is not an XLA tensor: CPULongType
error executing "python test_distributed.py": exit status 1
```
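As far as I can tell, the failure comes from `dist.all_gather_object` itself: it serializes the object into CPU byte/long tensors and calls `all_gather` on them, and the `xla` process group rejects those non-XLA tensors (hence `CPULongType`). A possible TPU-side alternative is `xm.mesh_reduce`, which exchanges serialized Python objects through a host-side rendezvous rather than through XLA collectives. A rough sketch (the `_gather_object` function name and the tag string are just illustrative):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _gather_object(index: int):
    # Any picklable Python object works here; no XLA tensors are involved.
    obj = {"rank": index, "payload": [index] * 3}
    # mesh_reduce serializes `obj`, exchanges it across processes via a host-side
    # rendezvous, and passes the list of per-process objects to the reduce
    # function; the identity lambda simply returns that list.
    gathered = xm.mesh_reduce("gather_object_demo", obj, lambda x: x)
    print(gathered)


if __name__ == "__main__":
    xmp.spawn(_gather_object)
```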
Hey @carlesoctav, I was talking about this code in accelerate:
```python
def _gpu_gather_object(object: Any):
    output_objects = [None for _ in range(PartialState().num_processes)]
    torch.distributed.all_gather_object(output_objects, object)
    # all_gather_object returns a list of lists, so we need to flatten it
    return [x for y in output_objects for x in y]


def gather_object(object: Any):
    """
    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.

    Returns:
        The same data structure as `object` with all the objects sent to every device.
    """
    if PartialState().distributed_type == DistributedType.XLA:
        # replace it by _gpu_gather_object(object)
        raise NotImplementedError("gather objects in TPU is not supported")
    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
        return _gpu_gather_object(object)
    else:
        return object
```
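For reference, a hypothetical sketch of what that XLA branch could look like, built on the same `xm.mesh_reduce` primitive (this is not accelerate's actual code, and the helper name is made up):

```python
from typing import Any

import torch_xla.core.xla_model as xm


def _tpu_gather_object(object: Any):
    # Hypothetical helper: mesh_reduce serializes `object`, gathers one copy from
    # every process via a host-side rendezvous, and the identity lambda returns
    # the list of per-process objects.
    gathered = xm.mesh_reduce("gather_object", object, lambda x: x)
    # Mirror _gpu_gather_object: flatten the list of lists into a single list.
    return [x for y in gathered for x in y]
```

The tag string only has to match across processes for the rendezvous; any stable name works.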
Do you have any plans to implement the above functionality on TPU? (Also, is this even possible on TPU?) I'm trying to distribute evaluation using accelerate. Error: