pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[torch_xla2] Functional collectives #7580

Closed will-cromar closed 2 months ago

will-cromar commented 3 months ago

Aside from the traceable op implementations being cleaner, dynamo rewrites torch.distributed calls into their functional equivalents. For example,

import torch
import torch.distributed as dist

@torch.compile(backend=my_backend)
def cc(index):
  dist.all_reduce(index)
  return index

generates the following FX graph:

opcode         name         target                                args                   kwargs
-------------  -----------  ------------------------------------  ---------------------  --------
placeholder    arg0_1       arg0_1                                ()                     {}
call_function  all_reduce   _c10d_functional.all_reduce.default   (arg0_1, 'sum', '0')   {}
call_function  wait_tensor  _c10d_functional.wait_tensor.default  (all_reduce,)          {}
call_function  copy         aten.copy.default                     (arg0_1, wait_tensor)  {}
output         output       output                                ((copy, copy),)        {}
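
For reference, the table above is the standard torch.fx tabular printout. A minimal sketch of a backend in the spirit of the my_backend used in the example (standard torch.compile custom-backend interface, running the captured graph eagerly rather than compiling it) would be:

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the captured FX graph in the tabular form shown above.
    gm.graph.print_tabular()
    # Run the captured graph eagerly instead of lowering it further.
    return gm.forward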

cc @qihqi

Depends on #7311
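
For context, a rough sketch of how a JAX-backed runtime like torch_xla2 could lower these functional ops. The function names, the mapping of the process group name to a mesh axis name, and the lack of any registration mechanism are illustrative assumptions here, not the actual torch_xla2 API:

import jax

# Hypothetical lowering sketch (not the torch_xla2 registration API).
# Assumes the inputs are already JAX arrays and that tracing happens
# inside a pmap/shard_map context with a named mesh axis.
def all_reduce(array, reduce_op: str, group_name: str):
    if reduce_op == "sum":
        return jax.lax.psum(array, axis_name=group_name)
    if reduce_op == "max":
        return jax.lax.pmax(array, axis_name=group_name)
    raise NotImplementedError(f"unsupported reduce op: {reduce_op}")

def wait_tensor(array):
    # JAX collectives are synchronous from the tracer's point of view,
    # so wait_tensor can be a no-op in this sketch.
    return array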