Open carmocca opened 2 months ago
## 🚀 Feature

### Motivation

We use `async_op=True` to reorder the collective calls in the trace, but that's not supported by Dynamo:

```
[rank7]: torch._dynamo.exc.Unsupported: CollectiveFunctionRewriteVariable can't support async_op=True for <function all_gather_into_tensor at 0x7f993a1bb6d0>
[rank7]: from user code:
[rank7]:   File "thunder.torch_interpreted_func_94", line 13, in torch_interpreted_func
[rank7]:     p8 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p8: "FUTURE cuda:7 bf16[2560, 2048]"
[rank7]:   File "/home/carlos/lightning-thunder/thunder/executors/torchex.py", line 1687, in _all_gather_prim_impl
[rank7]:     handle: None | torch.distributed.distributed_c10d.Work = torch.distributed.all_gather_into_tensor(
```
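For reference, the `async_op=True` pattern at issue returns a `torch.distributed.Work` handle instead of blocking, which is what allows the collective to be issued early and waited on later in the trace. A minimal, self-contained sketch (single-rank `gloo` group over a file store; it uses `all_gather` rather than the NCCL-oriented `all_gather_into_tensor` so it runs on CPU — illustrative only, not Thunder's executor code):

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Single-rank "gloo" process group initialized through a file store,
# so no networking or GPUs are required for this sketch.
store_file = os.path.join(tempfile.mkdtemp(), "store")
dist.init_process_group(
    "gloo", init_method=f"file://{store_file}", rank=0, world_size=1
)

t = torch.arange(4, dtype=torch.float32)
out = [torch.empty_like(t)]

# async_op=True makes the collective non-blocking: it returns a Work
# handle immediately, and the wait() can be scheduled later in the trace.
handle = dist.all_gather(out, t, async_op=True)
assert handle is not None  # a Work object, not None as in the sync case
handle.wait()              # block here until the collective completes
assert torch.equal(out[0], t)

dist.destroy_process_group()
```

It is exactly this `async_op=True` form that Dynamo's `CollectiveFunctionRewriteVariable` rejects under `fullgraph=True`, producing the error above.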
### Pitch

We could support this in two ways:

a. Disable `fullgraph=True` to allow graph breaks via #281
b. Do not let `torch_compile_ex` fuse through these collective calls.

Option (b) is what I would recommend.
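One way to picture option (b): the fuser treats every collective (and its matching wait) as a fusion boundary instead of pulling it into a fused region, so the compiled regions handed to Dynamo never contain an `async_op=True` call. A toy sketch of such a partitioning pass (the names, op representation, and structure here are hypothetical, not the actual `torch_compile_ex` API):

```python
# Hypothetical op names standing in for collective calls in a trace.
COLLECTIVES = {"all_gather", "reduce_scatter", "all_reduce", "wait"}

def partition(ops):
    """Split a linear list of op names into fusion groups, placing each
    collective in its own single-op group so no fused region crosses it."""
    groups, current = [], []
    for op in ops:
        if op in COLLECTIVES:
            if current:
                groups.append(current)  # close the pending fusable region
            groups.append([op])         # the collective stands alone
            current = []
        else:
            current.append(op)          # keep accumulating fusable ops
    if current:
        groups.append(current)
    return groups

# Example: the collectives break the trace into separate fusion regions.
print(partition(["mul", "all_gather", "add", "relu", "wait", "linear"]))
# → [['mul'], ['all_gather'], ['add', 'relu'], ['wait'], ['linear']]
```

With this kind of boundary, the collective calls stay outside the fused graphs and can keep using `async_op=True` for reordering, while everything between them is still eligible for fusion.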
### Alternatives

Do not support this.
### Additional context

Requires #140 to land first.
cc @carmocca @awaelchli @crcrpar @apaz-cli
triage review: (b) makes sense to us, too!