Closed MaciejBalaNV closed 3 weeks ago
@denera Could you look into this?
Also FYI @ptrblck.
@MaciejBalaNV Thanks for reporting this! While investigating your issue, I found that it is not actually limited to `noop_context_fn`. `torch.compile` + `te.distributed.checkpoint` causes compatibility issues with other context-based features like `torch.amp.autocast()` too.
What we actually need here is a `@torch._disable_dynamo` decorator on the TE checkpoint. This is also what the native PyTorch checkpoint uses to avoid `torch.compile` issues with context functions. In fact, the default behavior of the native PyTorch checkpoint is to use `noop_context_fn` when none is provided, so even if TE called it without a context function, it would still internally use `noop_context_fn`. It doesn't cause issues for them because they disable TorchDynamo for the checkpoint wrapper, and we should be doing the same in TE.
I've verified that this works on my end and I will be pushing up a PR for it shortly. I'd appreciate it if you could also confirm it resolves your issue. Thanks!
Hey @denera, thanks a lot for looking into it. It's useful to know that other context managers also have issues with `torch.compile` combined with `te.checkpoint`.
I just wanted to clarify one thing: there is a difference between passing `context_fn=noop_context_fn` and not passing `context_fn` at all. This difference comes from the fact that `context_fn` is not part of `kwargs` in this check, which means that no context manager is used at all.
@MaciejBalaNV It looks like the line you linked is checking for both `'context_fn' in kwargs` and `kwargs['context_fn'] != noop_context_fn`, so passing in `context_fn=noop_context_fn` here will fail at the second condition and execute the same code path as not passing a context function at all. Am I misunderstanding something there?
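For reference, the check in question can be sketched without torch; the local `noop_context_fn` here merely stands in for `torch.utils.checkpoint.noop_context_fn`:

```python
import contextlib

def noop_context_fn():
    # Mirrors torch.utils.checkpoint.noop_context_fn: a pair of no-op
    # context managers for the forward and recompute passes.
    return contextlib.nullcontext(), contextlib.nullcontext()

def dispatch(**kwargs):
    # Both conditions must hold to take the explicit-context path.
    if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
        return "explicit-context path"
    return "default path"

# In eager mode, both of these take the default path, because passing the
# bare noop_context_fn object fails the second condition:
dispatch()
dispatch(context_fn=noop_context_fn)
```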
The thing is that kwargs can be wrapped in `VariableTracker`s when using `torch.compile`. For example, that's why we have this check in PyTorch:

```python
if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
    context_fn = ctx.fn
```

In my experiments, the `context_fn` was wrapped as a `LazyVariableTracker`. It means that `kwargs['context_fn']` wasn't equal to `noop_context_fn`.
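A torch-free sketch of why the equality check fails once Dynamo wraps the kwarg; the `LazyVariableTracker` class below is a minimal stand-in for `torch._dynamo`'s tracker, not the real implementation:

```python
class LazyVariableTracker:
    # Stand-in for torch._dynamo's tracker: it holds the traced function,
    # but it is a different object, so comparing the wrapper against the
    # bare function fails even though the underlying function matches.
    def __init__(self, fn):
        self.fn = fn

def noop_context_fn():
    pass

wrapped = LazyVariableTracker(noop_context_fn)

assert wrapped != noop_context_fn      # the wrapper hides the function
assert wrapped.fn == noop_context_fn   # unwrapping restores equality
```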
In `transformer_engine.pytorch.distributed.checkpoint`, TE passes `context_fn` to `torch.utils.checkpoint.checkpoint` even if no `context_fn` was passed by the user to the original function, in which case `noop_context_fn` is used. This breaks when we are using `torch.compile`, for two reasons:

1. `context_fn` is wrapped in a `LazyVariableTracker`, which results in the error thrown here, as `LazyVariableTracker` is not checked for in the `if` chain.
2. Even after unwrapping the `LazyVariableTracker`, e.g. by adding it to the `if` chain, the error is still thrown here, as PyTorch only allows `context_fn` in combination with `torch.compile` when the `_experimental_support_context_fn_in_torch_utils_checkpoint` config is enabled.

It seems like the best option, which doesn't require modifying PyTorch or using an experimental config, is to simply not pass `context_fn` to `torch.utils.checkpoint.checkpoint` when `context_fn` is not specified by the user. I confirmed that this solution works with `torch.compile` without an issue.
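A minimal sketch of that fix, with an assumed wrapper name and a simplified signature (not the actual TE diff):

```python
import torch
import torch.utils.checkpoint

# Hypothetical simplified wrapper: only forward context_fn when the user
# actually supplied one, so torch.compile never sees the kwarg otherwise,
# and the native checkpoint applies its own noop_context_fn default
# internally.
def te_checkpoint(function, *args, context_fn=None, **kwargs):
    if context_fn is not None:
        kwargs["context_fn"] = context_fn
    return torch.utils.checkpoint.checkpoint(
        function, *args, use_reentrant=False, **kwargs
    )
```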