pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

torch.distributed.reduce vs torch_xla.core.xla_model.all_reduce #5022

Open RishabhPandit-00 opened 1 year ago

RishabhPandit-00 commented 1 year ago

❓ Questions and Help

I am a bit confused here. Can we use torch_xla.core.xla_model.all_reduce in place of torch.distributed.reduce? If yes, torch.distributed.reduce requires a destination rank; how do we handle that when using torch_xla.core.xla_model.all_reduce?

JackCaoG commented 1 year ago

@alanwaketan can you take this one?

alanwaketan commented 1 year ago

Well, XLA doesn't seem to support reduce by default. The way you can simulate a reduce op is to use a functional all_reduce and only let the dst rank use the returned value.
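
For example, a minimal sketch of that pattern (assuming a standard torch_xla setup; the helper name reduce_to_rank is just illustrative):

import torch_xla.core.xla_model as xm

def reduce_to_rank(tensor, dst_rank):
    # functional all_reduce: every rank gets the summed result back
    reduced = xm.all_reduce(xm.REDUCE_SUM, tensor)
    # only the destination rank consumes it, mimicking
    # torch.distributed.reduce(tensor, dst=dst_rank)
    if xm.get_ordinal() == dst_rank:
        return reduced
    # other ranks keep their original tensor
    return tensor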

May I ask what's your use case for reduce?

DevanshiParmar commented 1 year ago

@alanwaketan We want to use it in a custom function that implements the all-gather operation with gradient backward support, so that gradients can be properly computed and synchronized across all ranks in a distributed training scenario.

Below is the snippet of the function:

import torch
from torch import distributed


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward."""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        # collect every rank's tensor into gather_list
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        # reduce each incoming gradient onto the rank that owns the
        # corresponding input tensor
        dist_ops = [
            distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])

AllGather = AllGatherFunc.apply

alanwaketan commented 1 year ago

I see. @JackCaoG Have we asked the xla compiler about supporting reduce?

JackCaoG commented 1 year ago

Would https://github.com/pytorch/xla/blob/master/torch_xla/core/functions.py#L6 work?

alanwaketan commented 1 year ago

Would https://github.com/pytorch/xla/blob/master/torch_xla/core/functions.py#L6 work?

Good call. I remember seeing something similar.

DevanshiParmar commented 1 year ago

@alanwaketan @JackCaoG

Would https://github.com/pytorch/xla/blob/master/torch_xla/core/functions.py#L6 work?

Not sure about this. Can you please elaborate? Basically, we want to convert the above custom function to PyTorch/XLA so it can run on Google Cloud TPUs.

The above function belongs to this python file: https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc_v2.py

We replaced torch.distributed.all_reduce and distributed.all_gather with xm.all_reduce and xm.all_gather, but we are not able to find a corresponding replacement for distributed.reduce.
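
For reference, roughly what that replacement looks like (a sketch only; the tensor shapes and names are illustrative and assume the usual xm setup inside a multiprocess context):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.ones(4, 8, device=device)

# torch.distributed.all_reduce(t)  ->  functional all_reduce on XLA
t = xm.all_reduce(xm.REDUCE_SUM, t)

# torch.distributed.all_gather(gather_list, t)  ->  one concatenated tensor
gathered = xm.all_gather(t, dim=0)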

alanwaketan commented 1 year ago

Basically, can you use the all_gather in functions.py instead of your own all_gather in partial_fc_v2.py?
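
If that works for you, the swap might look roughly like this (a sketch, not tested; it assumes torch_xla.core.functions.all_gather keeps the autograd-aware all_gather(value, dim=0) signature from the functions.py linked above, and local_embeddings follows the naming in partial_fc_v2.py):

import torch_xla.core.functions as xf

def gather_embeddings(local_embeddings):
    # xf.all_gather is a torch.autograd.Function wrapper, so the backward
    # pass is handled for you and no hand-written reduce() is needed
    return xf.all_gather(local_embeddings, dim=0)

# usage sketch, replacing something like:
#   _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
#   embeddings = torch.cat(_list_embeddings)
# with:
#   embeddings = gather_embeddings(local_embeddings)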

RishabhPandit-00 commented 1 year ago

@alanwaketan, thanks for the guidance. Could you please confirm which all_gather function you mean? partial_fc_v2.py has AllGatherFunc, which is used in _list_embeddings = AllGather(local_embeddings, *_gather_embeddings), and there is also torch.distributed.all_gather.

DevanshiParmar commented 1 year ago

@alanwaketan @JackCaoG We modified the code as below, but it takes more time because we are running everything synchronously, whereas the original code uses both async and sync ops.

[screenshot of the modified code]