Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Support FSDP and torch.compile #298

Open · carmocca opened this issue 2 months ago

carmocca commented 2 months ago

🚀 Feature

Motivation

We use async_op=True to reorder the collective calls in the trace, but that's not supported by Dynamo:

[rank7]: torch._dynamo.exc.Unsupported: CollectiveFunctionRewriteVariable can't support async_op=True for <function all_gather_into_tensor at 0x7f993a1bb6d0>

[rank7]: from user code:
[rank7]:    File "thunder.torch_interpreted_func_94", line 13, in torch_interpreted_func
[rank7]:     p8 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p8: "FUTURE cuda:7 bf16[2560, 2048]"
[rank7]:   File "/home/carlos/lightning-thunder/thunder/executors/torchex.py", line 1687, in _all_gather_prim_impl
[rank7]:     handle: None | torch.distributed.distributed_c10d.Work = torch.distributed.all_gather_into_tensor(
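For context, here is a minimal standalone sketch of the pattern Dynamo rejects: an async collective inside a function compiled with fullgraph=True. This is illustrative, not Thunder's generated trace, and it assumes a process group has already been initialized (e.g. via torchrun).

```python
import torch
import torch.distributed as dist

@torch.compile(fullgraph=True)
def gather(x: torch.Tensor, world_size: int) -> torch.Tensor:
    out = torch.empty(world_size * x.shape[0], *x.shape[1:], device=x.device)
    # async_op=True returns a Work handle so the collective can be
    # overlapped with compute; Dynamo's CollectiveFunctionRewriteVariable
    # raises Unsupported for this, as in the traceback above.
    handle = dist.all_gather_into_tensor(out, x, async_op=True)
    handle.wait()
    return out
```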

Pitch

We could support this in two ways:

a. Disable fullgraph=True to allow graph breaks, via #281.
b. Do not let the torch_compile_ex fuse through these collective calls (sketched below).

Option (b) is what I would recommend.
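
For illustration, a hypothetical sketch of option (b): when partitioning a trace into fusion regions, treat distributed collectives as fusion breaks so they never land inside a torch.compile'd region. The Op type, the COLLECTIVE_OPS set, and the function name are made up for this sketch and are not Thunder's actual executor API.

```python
from dataclasses import dataclass

# Hypothetical: names of collective symbols that must not be fused.
COLLECTIVE_OPS = {"all_gather_prim_impl", "reduce_scatter_prim_impl", "all_reduce_prim_impl"}

@dataclass
class Op:
    name: str

def partition_into_fusion_groups(ops: list[Op]) -> list[list[Op]]:
    """Split a linear sequence of trace ops into fusion regions,
    breaking at collectives so each collective runs outside any
    torch.compile fusion (and Dynamo never sees async_op=True)."""
    groups: list[list[Op]] = []
    current: list[Op] = []
    for op in ops:
        if op.name in COLLECTIVE_OPS:
            if current:
                groups.append(current)
                current = []
            groups.append([op])  # the collective is its own, unfused region
        else:
            current.append(op)
    if current:
        groups.append(current)
    return groups
```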

Alternatives

Do not support this.

Additional context

Requires #140 to land first.

cc @carmocca @awaelchli @crcrpar @apaz-cli

mruberry commented 2 months ago

triage review: (b) makes sense to us, too!