NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Passing context_fn to torch.utils.checkpoint results in errors when using torch.compile #890

Closed MaciejBalaNV closed 3 weeks ago

MaciejBalaNV commented 4 weeks ago

In transformer_engine.pytorch.distributed.checkpoint, TE passes context_fn to torch.utils.checkpoint.checkpoint even if no context_fn was passed by the user to the original function, in which case noop_context_fn is used. This breaks when using torch.compile, for two reasons:

  1. PyTorch wraps context_fn in a LazyVariableTracker, which results in the error thrown here, since LazyVariableTracker is not checked for in the if chain.
  2. Even if we modify PyTorch to check for LazyVariableTracker, e.g. by adding the following to the if chain:
            elif isinstance(ctx, torch._dynamo.variables.LazyVariableTracker) and ctx.value == noop_context_fn:
                context_fn = noop_context_fn

the error is still thrown here, because PyTorch only allows context_fn to be passed in combination with torch.compile when the _experimental_support_context_fn_in_torch_utils_checkpoint config is enabled.

The best option, which requires neither modifying PyTorch nor using the experimental config, seems to be simply not passing context_fn to torch.utils.checkpoint.checkpoint when the user does not specify one. I confirmed that this solution works with torch.compile without any issues.
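
For reference, a minimal sketch of that idea (the wrapper name and signature are illustrative, not the actual TE code): the context_fn kwarg is only forwarded when the caller explicitly provides one.

    import torch
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    # Hypothetical wrapper illustrating the proposed workaround; not the actual
    # TE source. The context_fn kwarg is forwarded only when the caller provided
    # one, so torch.compile never has to handle a default noop_context_fn.
    def checkpoint_no_default_context(function, *args, context_fn=None, **kwargs):
        if context_fn is not None:
            return torch_checkpoint(
                function, *args, context_fn=context_fn, use_reentrant=False, **kwargs
            )
        # Omitting the kwarg lets PyTorch fall back to its own noop_context_fn
        # internally, which does not trip the Dynamo checks described above.
        return torch_checkpoint(function, *args, use_reentrant=False, **kwargs)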

ptrendx commented 4 weeks ago

@denera Could you look into this?

Also FYI @ptrblck.

denera commented 4 weeks ago

@MaciejBalaNV Thanks for reporting this! While investigating your issue, I found that it is not actually limited to noop_context_fn. torch.compile + te.distributed.checkpoint also causes issues with other context-based features, such as torch.amp.autocast() compatibility.

What we actually need here is a @torch._disable_dynamo decorator on the TE checkpoint. This is also what the native PyTorch checkpoint uses to avoid torch.compile issues with context functions. In fact, the default behavior of the native PyTorch checkpoint is to use noop_context_fn when none is provided, so even if TE called it without a context function, it would still internally use noop_context_fn. This doesn't cause issues for them because they disable TorchDynamo for the checkpoint wrapper, and we should do the same in TE.
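
As a rough sketch of what that could look like (names are illustrative, not the actual TE patch), assuming the wrapper ultimately delegates to torch.utils.checkpoint:

    import torch
    from torch.utils.checkpoint import checkpoint as torch_checkpoint, noop_context_fn

    # Illustrative sketch only: disable TorchDynamo around the checkpoint wrapper,
    # mirroring what the native torch.utils.checkpoint does, so context functions
    # (noop_context_fn, torch.amp.autocast(), ...) are never traced by torch.compile.
    @torch._disable_dynamo
    def te_style_checkpoint(function, *args, context_fn=noop_context_fn, **kwargs):
        return torch_checkpoint(
            function, *args, context_fn=context_fn, use_reentrant=False, **kwargs
        )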

I've verified that this works on my end and I will be pushing up a PR for it shortly. I'd appreciate it if you could also confirm it resolves your issue. Thanks!

MaciejBalaNV commented 4 weeks ago

Hey @denera, thanks a lot for looking into it. It's useful to know that other context managers also have issues with torch.compile combined with te.checkpoint.

I just wanted to clarify one thing: there is a difference between passing context_fn=noop_context_fn and not passing context_fn at all. The difference comes from the fact that when it is omitted, context_fn is not part of kwargs in this check, which means that no context manager is used at all.

denera commented 3 weeks ago

@MaciejBalaNV It looks like the line you linked is checking for both 'context_fn' in kwargs and kwargs['context_fn'] != noop_context_fn, so passing in context_fn=noop_context_fn here will fail at the second condition and execute the same code path as not passing a context function at all. Am I misunderstanding something there?

MaciejBalaNV commented 3 weeks ago

The thing is that kwargs can be wrapped in VariableTrackers when using torch.compile. For example, that is why we have this check in PyTorch:

            if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
                context_fn = ctx.fn

In my experiments, context_fn was wrapped as a LazyVariableTracker, which means that kwargs['context_fn'] was not equal to noop_context_fn.
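
To illustrate (LazyVariableTracker is a TorchDynamo-internal class, so the stand-in below is a hypothetical simplification): once Dynamo wraps the kwarg, a plain equality check against noop_context_fn no longer matches, even though the underlying value is still noop_context_fn.

    from torch.utils.checkpoint import noop_context_fn

    # Hypothetical stand-in for Dynamo's LazyVariableTracker, for illustration only.
    class FakeVariableTracker:
        def __init__(self, value):
            self.value = value

    wrapped = FakeVariableTracker(noop_context_fn)
    print(wrapped != noop_context_fn)        # True: the tracker is not the function itself
    print(wrapped.value is noop_context_fn)  # True: the wrapped value still is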