- Move collective op implementations out of `ProcessGroup` and into registered `__torch_dispatch__` ops.
- Rewrite the `ProcessGroup` implementation using functional collectives.
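
For context, a minimal, self-contained sketch of the `__torch_dispatch__` interception pattern that registered ops enable (illustrative only, not this PR's implementation):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    # Every op that goes through the dispatcher lands here -- including
    # collectives once they are registered as real ops instead of
    # ProcessGroup methods.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"dispatched: {func}")
        return func(*args, **(kwargs or {}))

with LogOps():
    torch.ones(2) + torch.ones(2)  # prints aten ops; a registered
                                   # c10d.allreduce_ shows up the same way
```
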
`torch.distributed` dispatches to `c10d.*` (e.g. `torch.ops.c10d.allreduce_`) instead of the functional op directly. The signatures of `c10d.*` and `_c10d_functional.*_inplace` differ somewhat (str reduce op vs. enum, tensor vs. list of tensors), so we can't just reuse the exact same function.
When I tried to implement `c10d.*` directly, all non-tensor objects were wrapped in `ScriptObject`, causing a bunch of errors. I could not figure out how to unwrap them.

Some functional collectives ignore the process group when a group name is given. Others look up e.g. the world size by the PG's group name in the background (namely `all_gather`). It's unclear if that's the long-term intended behavior. Let's keep the `ProcessGroup` for now.
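
For example, the background lookup for `all_gather` appears to resolve the group from its registered name, roughly like this (assuming the private `_resolve_process_group` helper; the real resolution may live elsewhere):

```python
from torch.distributed.distributed_c10d import _resolve_process_group

def gathered_numel(inp, group_name):
    # The functional all_gather needs the world size to size its output;
    # it recovers it from the registered group name behind the scenes.
    pg = _resolve_process_group(group_name)
    return pg.size() * inp.numel()
```
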
Aside from the traceable op implementations being cleaner, dynamo rewrites `torch.distributed` calls into their functional equivalents. E.g.
```python
import torch
import torch.distributed as dist

@torch.compile(backend=my_backend)  # my_backend: any custom dynamo backend
def cc(index):
    dist.all_reduce(index)
    return index
```
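
With a backend that prints the captured graph, the traced program ends up calling the functional variant (along the lines of `torch.ops._c10d_functional.all_reduce`) instead of `c10d.allreduce_`.
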
cc @qihqi

Depends on #7311