huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

gather objects in TPU is not supported #2858

Open carlesoctav opened 5 months ago

carlesoctav commented 5 months ago

Do you have any plans to implement gathering objects on TPU (and is that even possible on TPU)? I'm trying to distribute evaluation using Accelerate and get this error:

```
    raise NotImplementedError("gather objects in TPU is not supported")
NotImplementedError: gather objects in TPU is not supported
```
SunMarc commented 5 months ago

Hi @carlesoctav, thanks for opening the issue. torch.distributed is not fully supported on TPU v2/v3. See the related docs here. Can you check whether torch.distributed.all_gather_object works on TPU? If so, we could extend this function to support TPUs as well.

carlesoctav commented 5 months ago

Do you think this part of the documentation is relevant to your question? I can run the code below on a v4-8:

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  xmp.spawn(_all_gather)
```

Note: Although the xla:// init_method is not required on TPU v4, it is still recommended. If you use env://, MASTER_ADDR must be set to the IP of the host that has device 0, which is not always worker 0. The xla:// init_method finds this IP automatically.

Note: For TPU v2/v3, you still need to import torch_xla.experimental.pjrt_backend, as TPU v2/v3 support in torch.distributed is still experimental.

carlesoctav commented 5 months ago

Well, if I naively change all_gather to all_gather_object, the code above fails with this error:

```
*** End stack trace ***
Input tensor is not an XLA tensor: CPULongType
error executing "python test_distributed.py": exit status 1
```
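The `CPULongType` in this message hints at the cause. As far as I can tell from PyTorch's implementation, `torch.distributed.all_gather_object` pickles each object into a byte tensor, first all-gathers the payload sizes as a `torch.long` tensor, pads every payload to the largest size, all-gathers the padded bytes, and finally unpickles each trimmed buffer. If those helper tensors end up on the CPU rather than on the XLA device, the `xla` backend would reject them, which is consistent with this error (an assumption, not verified on TPU). The pure-Python sketch below only simulates that protocol in a single process: `simulate_all_gather_object` is a hypothetical helper, and plain lists stand in for the real collectives.

```python
import pickle
from typing import Any, List


def simulate_all_gather_object(objects_per_rank: List[Any]) -> List[Any]:
    """Hypothetical single-process simulation of the all_gather_object protocol.

    objects_per_rank[i] plays the role of the object passed in on rank i.
    Returns the list every rank would end up with.
    """
    # Serialize each rank's object, as all_gather_object does before communicating.
    payloads = [pickle.dumps(obj) for obj in objects_per_rank]

    # Step 1: "all_gather" the payload sizes (a torch.long tensor in PyTorch).
    sizes = [len(p) for p in payloads]

    # Step 2: pad every payload to the largest size so all ranks send equal-length buffers.
    max_size = max(sizes)
    padded = [p + b"\x00" * (max_size - len(p)) for p in payloads]

    # Step 3: "all_gather" the padded buffers (a torch.uint8 tensor in PyTorch).
    # Here the collective is simulated by simply copying the list.
    gathered = list(padded)

    # Step 4: trim each buffer back to its recorded size and unpickle it.
    return [pickle.loads(buf[:size]) for buf, size in zip(gathered, sizes)]


if __name__ == "__main__":
    per_rank = [["a", 1], {"rank": 1}, ("x",)]
    print(simulate_all_gather_object(per_rank))
```

Padding is what lets every rank contribute an equal-length buffer, which a tensor all_gather requires; the gathered sizes are what allow each buffer to be trimmed before unpickling.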
SunMarc commented 5 months ago

Hey @carlesoctav, I was referring to this code in Accelerate:

```python
from typing import Any

import torch

from accelerate.state import PartialState
from accelerate.utils.constants import TORCH_DISTRIBUTED_OPERATION_TYPES
from accelerate.utils.dataclasses import DistributedType


def _gpu_gather_object(object: Any):
    output_objects = [None for _ in range(PartialState().num_processes)]
    torch.distributed.all_gather_object(output_objects, object)
    # all_gather_object fills output_objects in place with one entry per process;
    # each entry is that process's list, so flatten the resulting list of lists
    return [x for y in output_objects for x in y]

def gather_object(object: Any):
    """
    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.

    Returns:
        The same data structure as `object` with all the objects sent to every device.
    """
    if PartialState().distributed_type == DistributedType.XLA:
        # Proposed: replace this with `return _gpu_gather_object(object)`
        raise NotImplementedError("gather objects in TPU is not supported")
    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
        return _gpu_gather_object(object)
    else:
        return object
```
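With the code as quoted, `_gpu_gather_object` assumes every process passes a list: `all_gather_object` fills `output_objects[i]` with the object from rank `i`, and the comprehension concatenates those per-rank lists. The pure-Python illustration below hard-codes what `output_objects` would contain on two processes instead of running a real collective; `flatten_gathered` is a hypothetical name for the same flattening expression.

```python
from typing import Any, List


def flatten_gathered(output_objects: List[Any]) -> List[Any]:
    # Same flattening expression as in _gpu_gather_object.
    return [x for y in output_objects for x in y]


# What output_objects would hold on every rank with 2 processes,
# if rank 0 passed ["a", "b"] and rank 1 passed ["c"].
print(flatten_gathered([["a", "b"], ["c"]]))  # ['a', 'b', 'c']

# If each rank passes a dict instead of a list, iterating it yields only its keys.
print(flatten_gathered([{"loss": 0.1}, {"loss": 0.2}]))  # ['loss', 'loss']
```

So when distributing evaluation, wrapping each process's results in a list before calling `gather_object` keeps the flattened output meaningful.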