carmocca opened 7 months ago
I can work around this by setting

tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp

since thunder gets this information at jit() time: https://github.com/Lightning-AI/lightning-thunder/blob/94c94948b79875ba5247b5c986afa088d970a49d/thunder/common.py#L224-L226

So my question is: could we delay accessing this attribute until the function runs?
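For context, a minimal sketch of the workaround in place (hedged: it assumes thunder.jit and thunder.distributed.ddp as the entry points, that ddp() can be applied to the already-jitted module, and it relies on the private _lc_cd CompileData handle; the toy Linear model is only illustrative):

```python
# Minimal sketch of the workaround. Run under torchrun: ddp() needs an
# initialized torch.distributed process group.
import torch
import thunder
import thunder.distributed

model = torch.nn.Linear(8, 8)
tmodel = thunder.jit(model)               # the user jits first
tmodel = thunder.distributed.ddp(tmodel)  # DDP is applied afterwards

# thunder read process_group_for_ddp from the original fn at jit() time,
# before ddp() set it, so copy it over from the now-annotated fn by hand.
tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp
```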
TBH, this is a very clear "don't do this, changing the fn is completely unsupported!".
That said, we can talk about distributed-after-jit; there are some obstacles to work through. I'll reach out to you to understand the need better.
triage review:

- Let's look at our current transforms, when they have to be applied, and what they mean when ordered after each other (are all orders supported?). For example: ddp, jit, grad? jit, grad, ddp? See the sketch after this list.
- Do we need to support transforms that change the original module, or maybe produce a new module?
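To make the ordering question concrete, here is a sketch of the two compositions discussed in this thread (hedged: thunder.jit and thunder.distributed.ddp are the only entry points assumed; the grad transform is left out, and the toy model is only illustrative):

```python
# The two orderings under discussion. Run under torchrun with an initialized
# process group.
import torch
import thunder
import thunder.distributed

model = torch.nn.Linear(8, 8)

# Supported today: annotate the module with ddp() first, then jit it.
jit_after_ddp = thunder.jit(thunder.distributed.ddp(model))

# What this issue asks for: jit first, then apply ddp() to the jitted module.
# Today this misses process_group_for_ddp because it is read at jit() time.
ddp_after_jit = thunder.distributed.ddp(thunder.jit(model))
```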
Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.
I would appreciate some pointers or examples that show this, because in my test the trace does look correct: it contains the appropriate collectives.
I'm probably misunderstanding how the interpreter works. How can the prologues be generated at jit time if we don't have any input tensors for which to check shapes? I thought this would only happen on the first call.
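For what it's worth, a small sketch of the behavior being asked about (hedged: thunder.last_traces is assumed to be available in the installed version, and the toy model is only illustrative):

```python
# thunder.jit returns immediately; the prologue/computation traces are only
# produced once the jitted module is called with concrete inputs.
import torch
import thunder

model = torch.nn.Linear(8, 8)
tmodel = thunder.jit(model)   # no inputs seen yet, nothing traced

x = torch.randn(2, 8)
tmodel(x)                     # first call: inputs are interpreted, traces exist now

print(thunder.last_traces(tmodel)[-1])  # inspect the final computation trace
```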
FSDP and DDP calls are not trace transforms, they are parameter annotators of the original to-be-jitted PyTorch module.
Do you need ddp(jit(model)) to work, and is it more important to support than jit(ddp(model))? We still need to support jit(ddp(model)), as this is basically what happens whenever you jit a function and not the model.
What I'm advocating for is something like jit(ddp(undo_jit(jit(model)))), where undo_jit is currently the hack that I describe in the top post.
Allowing this is convenient because the user can control the innermost jit(model) call while the framework (Fabric, Trainer) controls the transforms applied to the model, and how they interact with each other if there is more than one.
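A hypothetical sketch of what that could look like (undo_jit and setup_ddp are not existing thunder APIs; the _lc_cd handle is the private CompileData object from the top post):

```python
import thunder
import thunder.distributed


def undo_jit(jitted_module):
    # Hypothetical helper: recover the original module passed to thunder.jit,
    # here via the private CompileData handle mentioned in the top post.
    return jitted_module._lc_cd.fn


def setup_ddp(jitted_module):
    # The framework (Fabric/Trainer) controls the transform; the user controls
    # the innermost thunder.jit(model) call.
    return thunder.jit(thunder.distributed.ddp(undo_jit(jitted_module)))
```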
I know nothing about Lightning. Do you want to allow users to call jit(model), and then inside Lightning you apply either DDP or FSDP to the given model? FSDP is now working, right? You need something that unwraps the jit call. Have you tried using __wrapped__? thunder.jit uses functools.wraps here: https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/__init__.py#L642
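As background on the __wrapped__ suggestion: functools.wraps stores a reference to the wrapped callable on the wrapper (standard library behavior; the jit_like wrapper below is only a stand-in, not thunder's implementation):

```python
import functools


def jit_like(fn):
    @functools.wraps(fn)  # sets wrapper.__wrapped__ = fn, among other things
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def f(x):
    return x + 1


g = jit_like(f)
assert g.__wrapped__ is f  # the original callable can be recovered from the wrapper
```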
One thing (probably tangential) I was wondering: why is process_group_for_ddp an attribute of CompileData? I think it would make sense to make it a property, because if a scenario comes up where we have to update CompileData.fn, we might forget to update the corresponding attributes. (This change could potentially also fix the issue.)
```diff
diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:
         self.is_module = isinstance(self.fn, torch.nn.Module)
 
-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:
         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"
 
+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+
```
🐛 Bug
The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: https://github.com/Lightning-AI/litgpt/pull/1204

The objective is that fsdp|ddp can be applied after the thunder.jit call. It works with FSDP, but not with DDP, where it fails with:
To Reproduce
torchrun --nproc-per-node 2 bug.py
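The original bug.py is not reproduced in this thread; a hedged reconstruction of the described flow (jit first, then DDP applied afterwards) might look like this:

```python
# bug.py -- illustrative reconstruction only, not the snippet from the report.
import torch
import torch.distributed as dist
import thunder
import thunder.distributed

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = torch.nn.Linear(8, 8, device="cuda")
    tmodel = thunder.jit(model)               # user-controlled jit call
    tmodel = thunder.distributed.ddp(tmodel)  # DDP applied after jit (the failing path)

    out = tmodel(torch.randn(2, 8, device="cuda"))
    out.sum().backward()
```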
cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23