Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

adding DDP/FSDP transform after JITting does not work #94

Open carmocca opened 7 months ago

carmocca commented 7 months ago

🐛 Bug

The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: https://github.com/Lightning-AI/litgpt/pull/1204

The goal is to be able to apply fsdp or ddp after the thunder.jit call.

It works with FSDP but not with DDP, where it fails with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]:     out = tmodel(x)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]:     res = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]:     computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]:     bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]:     updated_bwd_trace = visitor_transform(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]:     visit_type = visit(bsym)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]:     self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]:     self._maybe_allreduce(bucket, group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]:     self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]:     result = self.meta(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]:     result = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]:     utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]:     check(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:     raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>

To Reproduce

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

tmodel = thunder.jit(model)
# Hack: after jit(), replace the jitted function with its DDP-annotated version
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)

out = tmodel(x)

if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())

torchrun --nproc-per-node 2 bug.py

cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23

carmocca commented 7 months ago

I can work around this by setting

tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp

because thunder reads this attribute at jit() time: https://github.com/Lightning-AI/lightning-thunder/blob/94c94948b79875ba5247b5c986afa088d970a49d/thunder/common.py#L224-L226

So my question is: could we delay accessing this attribute until the function runs?
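To make the failure mode concrete, here is a minimal pure-Python sketch. EagerCompileData, Model, and toy_ddp are hypothetical stand-ins, not thunder classes; the sketch only assumes, as the linked common.py lines suggest, that the attribute is read once when the compile data is built, so swapping fn afterwards leaves a stale None.

```python
# Hypothetical stand-ins, NOT thunder's real classes: they only model
# "read the annotation once at construction time".


class EagerCompileData:
    """Reads the DDP annotation once, at construction (i.e. at jit() time)."""

    def __init__(self, fn):
        self.fn = fn
        # Snapshot taken now; later changes to self.fn are not reflected here.
        self.process_group_for_ddp = getattr(fn, "process_group_for_ddp", None)


class Model:
    pass


def toy_ddp(model, group="default-group"):
    """Hypothetical annotator: tags the object and returns it unchanged."""
    model.process_group_for_ddp = group
    return model


cd = EagerCompileData(Model())   # "jit time": no annotation yet
cd.fn = toy_ddp(cd.fn)           # the hack: annotate after "jit"
stale = cd.process_group_for_ddp  # still None, analogous to the ValueError above

# The manual workaround: copy the attribute over by hand.
cd.process_group_for_ddp = cd.fn.process_group_for_ddp
fixed = cd.process_group_for_ddp
```

Delaying the read until the function runs would remove the need for the manual copy.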

t-vi commented 7 months ago

TBH, this is a very clear "don't do this, changing the fn is completely unsupported!".

That said, we can talk about distributed-after-jit. The obstacles are:

I'll reach out to you to understand the need better.

mruberry commented 7 months ago

triage review:

Let's look at our current transforms: when they have to be applied, and what they mean when composed in different orders (are all orders supported?). For example, is ddp, then jit, then grad supported? What about jit, then grad, then ddp?

Do we need to support transforms that modify the original module, or should they produce a new module?

carmocca commented 7 months ago

> Currently the ddp transformation is applied during JITting (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.

I would appreciate some pointers or examples that show this, because in my test the trace does look correct: it contains the expected collectives.

I'm probably misunderstanding how the interpreter works. How can the prologues be generated at jit time if we don't have any input tensors whose shapes we can check? I thought this only happened on the first call.

IvanYashchuk commented 7 months ago

FSDP and DDP calls are not trace transforms; they are parameter annotators of the original, to-be-jitted PyTorch module.

carmocca commented 6 months ago

We still need to support jit(ddp(model)), as this is basically what happens whenever you jit a function rather than the model.

What I'm advocating for is something like jit(ddp(undo_jit(jit(model))))

where undo_jit is currently the hack that I describe in the first post.

Allowing this is convenient because the user can then control the innermost jit(model) call, while the framework (Fabric, Trainer) controls which transforms are applied to the model and how they interact with each other when there is more than one.
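A toy model of the proposed composition, to show why re-running jit on the annotated function avoids the stale jit-time read. ToyJitted, toy_jit, undo_jit, and toy_ddp are all hypothetical; none of them is a thunder API, and a real undo_jit would also have to discard any jit-specific state.

```python
# Hypothetical toy versions of jit / undo_jit / ddp. None of these are thunder APIs.


class ToyJitted:
    """Wraps a callable and records the DDP annotation it sees at jit time."""

    def __init__(self, fn):
        self.fn = fn
        self.process_group_for_ddp = getattr(fn, "process_group_for_ddp", None)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


def toy_jit(fn):
    return ToyJitted(fn)


def undo_jit(jitted):
    """Recover the original callable from a ToyJitted wrapper."""
    return jitted.fn


def toy_ddp(fn, group="default-group"):
    """Annotates the callable in place (note: mutates the shared object)."""
    fn.process_group_for_ddp = group
    return fn


def model(x):
    return 2 * x


user_jitted = toy_jit(model)                          # the user controls this call
framework = toy_jit(toy_ddp(undo_jit(user_jitted)))  # the framework re-applies jit
```

Here framework.process_group_for_ddp is "default-group" because jit reads the annotation after ddp ran, while user_jitted still holds its original None snapshot.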

IvanYashchuk commented 6 months ago

I know nothing about Lightning. Do you want to let users call jit(model) and then, inside Lightning, apply either the DDP or FSDP call to that model? FSDP is working now, right? You need something that unwraps the jit call. Have you tried using __wrapped__? thunder.jit uses functools.wraps here: https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/__init__.py#L642
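For reference, this is how functools.wraps exposes the original callable through __wrapped__ in plain Python. toy_jit is a hypothetical decorator; whether the object returned by thunder.jit (especially for an nn.Module) exposes __wrapped__ depends on the thunder version and is not confirmed here.

```python
import functools
import inspect


def toy_jit(fn):
    """Hypothetical jit-like decorator; only the functools.wraps part matters."""

    @functools.wraps(fn)  # copies __name__, __doc__, etc. and sets __wrapped__ = fn
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def forward(x):
    """Original function."""
    return x + 1


jitted = toy_jit(forward)
original = jitted.__wrapped__           # direct access to the wrapped callable
also_original = inspect.unwrap(jitted)  # follows the whole __wrapped__ chain
```

inspect.unwrap is handy when several wraps-based decorators are stacked, since it keeps following __wrapped__ until it reaches the innermost callable.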

kshitij12345 commented 6 months ago

One thing I was wondering (probably tangential): why is process_group_for_ddp an attribute of CompileData?

https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/common.py#L221-L223

I think it would make sense to make it a property, because if we ever have to update CompileData.fn, we might miss updating these corresponding attributes. (This change could potentially also fix the issue.)

diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:

         self.is_module = isinstance(self.fn, torch.nn.Module)

-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:

         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"

+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+
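A pure-Python sketch of the effect of the proposed property, using a hypothetical LazyCompileData stand-in rather than thunder's real CompileData: because the value is derived from fn on every access, replacing fn after construction is picked up automatically.

```python
class Model:
    pass


class LazyCompileData:
    """Hypothetical stand-in mirroring the diff: the value is derived from fn."""

    def __init__(self, fn):
        self.fn = fn

    @property
    def process_group_for_ddp(self):
        # Re-read on every access, so replacing self.fn is picked up automatically.
        return getattr(self.fn, "process_group_for_ddp", None)


cd = LazyCompileData(Model())
before = cd.process_group_for_ddp  # None: not annotated yet

annotated = Model()
annotated.process_group_for_ddp = "default-group"
cd.fn = annotated                  # swap fn after construction
after = cd.process_group_for_ddp   # now reflects the new fn
```

One side effect of a read-only property: assigning to it directly (as in the manual workaround earlier in the thread) raises AttributeError, so callers have to annotate fn instead.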