RishabhPandit-00 opened this issue 1 year ago
@alanwaketan can you take this one?
Well, XLA doesn't seem to support reduce by default. The way you can simulate a reduce op is to use a functional all_reduce and only let the dst rank use the returned value.
May I ask what's your use case for reduce?
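For illustration, a minimal sketch of that simulation, assuming torch_xla is available and the code runs inside a multi-process XLA worker; `reduce_to_dst` is a hypothetical helper name, not an existing API:

```python
import torch_xla.core.xla_model as xm

def reduce_to_dst(tensor, dst):
    """Hypothetical helper: approximate torch.distributed.reduce(tensor, dst, SUM)
    with a functional all_reduce, letting only the destination rank use the result."""
    # Every rank computes the global sum (all_reduce has no destination-rank argument).
    summed = xm.all_reduce(xm.REDUCE_SUM, tensor)
    # Only the destination rank consumes the reduced value; other ranks keep their input.
    return summed if xm.get_ordinal() == dst else tensor
```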
@alanwaketan We want to use it in a custom function that implements the all-gather operation with a gradient backward pass, so that gradients are correctly computed and synchronized across all ranks in a distributed training scenario.
Below is a snippet of the function:
```python
import torch
import torch.distributed as distributed


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward."""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        # Reduce each rank's gradient onto the rank that owns it, asynchronously.
        dist_ops = [
            distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


AllGather = AllGatherFunc.apply
```
I see. @JackCaoG Have we asked the XLA compiler team about supporting reduce?
Would https://github.com/pytorch/xla/blob/master/torch_xla/core/functions.py#L6 work?
Good call. I remember seeing something similar.
@alanwaketan @JackCaoG
> Would https://github.com/pytorch/xla/blob/master/torch_xla/core/functions.py#L6 work?

Not sure about this. Can you please elaborate? Basically, we want to convert the above custom function to PyTorch/XLA so it can run on Google Cloud TPUs.
The above function belongs to this python file: https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc_v2.py
We replaced torch.distributed.all_reduce and distributed.all_gather with xm.all_reduce and xm.all_gather, but we are not able to find a corresponding replacement for distributed.reduce.
Basically, can you use the all_gather in functions.py instead of your all_gather in partial_fc_v2.py?
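For context, a rough sketch of that suggestion, assuming the gradient-aware all_gather defined in torch_xla/core/functions.py; note that it returns a single tensor concatenated along a dimension rather than a list of per-rank tensors, so the caller in partial_fc_v2.py would need to be adapted:

```python
import torch_xla.core.functions as xf

# Instead of the custom AllGatherFunc plus a list of per-rank buffers,
# the all_gather in torch_xla.core.functions handles the backward pass itself.
gathered_embeddings = xf.all_gather(local_embeddings, dim=0)
```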
@alanwaketan Thanks for the guidance. Could you please confirm which all_gather function you mean? partial_fc_v2.py has AllGatherFunc, which is used as _list_embeddings = AllGather(local_embeddings, *_gather_embeddings), and there is also torch.distributed.all_gather.
@alanwaketan @JackCaoG We modified the code as below, but it takes more time because we are now running synchronously, whereas the original code used both async and sync ops.
❓ Questions and Help
I am a bit confused here. Can we use torch_xla.core.xla_model.all_reduce in place of torch.distributed.reduce? If yes: torch.distributed.reduce needs a destination rank, so how do we handle that when using torch_xla.core.xla_model.all_reduce?
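One possible shape for that, sketched below and untested: xm.all_reduce has no destination-rank argument, so every rank computes the summed gradient and each rank simply keeps the slice corresponding to its own local tensor. The class name AllGatherFuncXLA and its single-tensor signature are assumptions made for illustration, not part of torch_xla:

```python
import torch
import torch_xla.core.xla_model as xm


class AllGatherFuncXLA(torch.autograd.Function):
    """Hypothetical XLA port of AllGatherFunc (untested sketch)."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.local_rows = tensor.shape[0]
        # xm.all_gather returns one tensor concatenated along dim 0,
        # instead of the list filled in by torch.distributed.all_gather.
        return xm.all_gather(tensor, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # All ranks compute the summed gradient; there is no dst rank to pass.
        summed = xm.all_reduce(xm.REDUCE_SUM, grad_output)
        # Each rank keeps only the slice that corresponds to its local tensor.
        rank = xm.get_ordinal()
        start = rank * ctx.local_rows
        grad = summed[start:start + ctx.local_rows]
        return grad * xm.xrt_world_size()  # mirror the original world-size scaling
```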