pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

`export(..., pre_dispatch=True)` for model in eval mode still inserts autograd ops #106137

Open andrewor14 opened 1 year ago

andrewor14 commented 1 year ago

Support for export(..., pre_dispatch=True) was added to make sure autograd (and autocast) functionality (e.g. torch.no_grad) continues to work on exported graphs. This is needed for training on an exported model. The way this works is by inserting torch._C._set_grad_enabled ops into the graph.

We have a use case where we want to use export(..., pre_dispatch=True) for models in eval mode. However, I just tried this out and it looks like I'm still seeing these torch._C._set_grad_enabled ops in the graph. This may lead to unexpected behavior during eval, where some gradients may be computed even though we don't actually need them.

Minimal repro:

import torch
import torch._dynamo

class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        x = self.linear1(x)
        with torch.no_grad():
            x = self.linear2(x)
        return x

example_inputs = (torch.randn(10, 10),)
m = ToyModel().eval()
m, _ = torch._dynamo.export(
    m,
    *example_inputs,
    aten_graph=True,
    pre_dispatch=True,
)
print(m)

Output:

def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    linear_default = torch.ops.aten.linear.default(arg0, _param_constant0, _param_constant1);  arg0 = _param_constant0 = _param_constant1 = None
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    _param_constant2 = self._param_constant2
    _param_constant3 = self._param_constant3
    linear_default_1 = torch.ops.aten.linear.default(linear_default, _param_constant2, _param_constant3);  linear_default = _param_constant2 = _param_constant3 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    return pytree.tree_unflatten([linear_default_1], self._out_spec)
bdhirsh commented 1 year ago

Some thoughts from my offline conversation with Andrew:

This behavior is technically expected: setting a model to .eval() doesn't actually change whether or not autograd runs, so pre_dispatch tracing will pick up the no_grad() call during tracing and put a set_grad_enabled() into the graph.
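
A quick illustration of that point:

import torch

# .eval() only changes module-level behavior (dropout, batchnorm, ...);
# it does not turn off autograd.
m = torch.nn.Linear(10, 10).eval()
x = torch.randn(2, 10)

out = m(x)
print(out.requires_grad)  # True: autograd still tracks the computation in eval mode

with torch.no_grad():     # only grad-mode context managers actually disable autograd
    out = m(x)
print(out.requires_grad)  # False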

It feels to me like there are two high-level options:

(1) figure out how to make this work

The reason we want to trace the set_grad_enabled() calls in the above example is that those calls run during normal model execution. inference_mode() is technically the tool we have to say "I want to skip all autograd behavior, which normally runs with the Autograd dispatch key." One option would be to require the user to do the export in inference mode, and tell pre_dispatch tracing that it should not bake autograd APIs like the above into the graph if inference mode is currently active.
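
From the user's side, that could look something like the sketch below (reusing ToyModel and example_inputs from the repro above). This is only a sketch of the proposal: today the set_grad_enabled calls would still be traced, and whether export runs cleanly under inference mode is itself something that would need to be verified.

import torch
import torch._dynamo

m = ToyModel().eval()
example_inputs = (torch.randn(10, 10),)

with torch.inference_mode():
    # Proposed behavior: pre_dispatch tracing would notice inference mode is active
    # and skip baking torch._C._set_grad_enabled calls into the graph.
    gm, _ = torch._dynamo.export(
        m,
        *example_inputs,
        aten_graph=True,
        pre_dispatch=True,
    )
print(gm)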

This feels pretty fragile though. In particular, we also need to make sure that pre_dispatch tracing does the right thing on models that use autograd.Function and torch.utils.checkpoint, and we'd need some testing to make sure that these work properly.
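
Such a test would need to cover models along these lines (a minimal sketch with made-up names, combining a custom autograd.Function with activation checkpointing):

import torch
from torch.utils.checkpoint import checkpoint

class Double(torch.autograd.Function):
    # Trivial custom autograd.Function: y = 2 * x, so dy/dx = 2.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

class CheckpointedToy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = Double.apply(x)
        # Recompute the linear layer in backward instead of saving activations.
        return checkpoint(self.linear, x, use_reentrant=False)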

(2) Don't try to make this work

It sounds like the main motivation here is that QAT uses pre_dispatch tracing, and we can minimize the amount of divergence across graphs if PTQ also uses pre_dispatch tracing. Maybe this is something that we can fix / make less painful. The most glaring divergence (if you only care about inference) is that CompositeImplicitAutograd decomps don't run in pre_dispatch tracing, but they run in "normal" tracing. There's probably some work that we can do here, to make it easier to run pre_dispatch tracing but still have all CompositeImplicitAutograd decomps run. This PR is one step in the right direction: https://github.com/pytorch/pytorch/pull/105865, but we'd also need a convenient util to let a user grab all existing ops with a CompositeImplicitAutograd decomp.
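
A rough sketch of what that util could look like (assuming the private binding torch._C._dispatch_get_registrations_for_dispatch_key, which lists the op names registered to a given dispatch key; the exact binding may differ across versions):

import torch

def ops_with_composite_implicit_autograd():
    # Names of all ops registered to the CompositeImplicitAutograd key
    # (e.g. "aten::matmul"); these decompose further in "normal" tracing
    # but are preserved by pre_dispatch tracing.
    names = torch._C._dispatch_get_registrations_for_dispatch_key(
        "CompositeImplicitAutograd"
    )
    return sorted(n for n in names if n.startswith("aten::"))

print(len(ops_with_composite_implicit_autograd()))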