andrewor14 opened 1 year ago
Some thoughts from my offline conversation with Andrew:
This behavior is technically expected: setting a model to .eval() doesn't actually change whether or not autograd runs, so pre_dispatch tracing will pick up on the no_grad() call during tracing, putting a set_grad_enabled() into the graph.
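To make that concrete, here is a small standalone sketch (not from the issue) showing that eval mode and gradient tracking are controlled independently:

```python
import torch

m = torch.nn.Linear(4, 4).eval()  # .eval() only flips module mode (dropout, batchnorm, etc.)
x = torch.randn(1, 4)

out = m(x)
print(out.requires_grad)  # True: autograd still runs for an eval-mode model

with torch.no_grad():  # this is the kind of call pre_dispatch tracing records as set_grad_enabled
    out = m(x)
print(out.requires_grad)  # False
```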
It feels to me like there are two high-level options:
(1) Figure out how to make this work
The reason we want to trace the set_grad_enabled() calls in the above example is that those calls run during normal model execution. inference_mode() is technically the tool we have to say "I want to skip all autograd behavior, which normally runs with the Autograd dispatch key". One option would be to require the user to do the export in inference mode, and tell pre_dispatch tracing that it should not bake autograd APIs like the above into the graph if inference mode is currently active.
This feels pretty fragile though. In particular, we also need to make sure that pre_dispatch tracing does the right thing on models that use autograd.Function and torch.utils.checkpoint, and we'd need some testing to make sure that these work properly.
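From the user's side, option (1) would roughly look like the sketch below. This is hypothetical: the inference_mode() context is real, but the "skip baking autograd APIs when inference mode is active" behavior it relies on is the proposal above, not something pre_dispatch tracing does today, and the exact export entry point varies across PyTorch versions.

```python
import torch
from torch._export import export  # pre_dispatch entry point; exact location varies by version

m = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(1, 4),)

# Hypothetical workflow: export under inference_mode, and have pre_dispatch tracing
# notice that inference mode is active and skip recording no_grad / set_grad_enabled
# calls into the graph.
with torch.inference_mode():
    ep = export(m, example_inputs, pre_dispatch=True)

print(ep.graph_module.graph)
```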
(2) Don't try to make this work
It sounds like the main motivation here is that QAT uses pre_dispatch tracing, and we can minimize the amount of divergence across graphs if PTQ also uses pre_dispatch tracing. Maybe this is something that we can fix / make less painful. The most glaring divergence (if you only care about inference) is that CompositeImplicitAutograd decomps don't run in pre_dispatch tracing, but they run in "normal" tracing. There's probably some work that we can do here to make it easier to run pre_dispatch tracing but still have all CompositeImplicitAutograd decomps run. This PR is one step in the right direction: https://github.com/pytorch/pytorch/pull/105865, but we'd also need a convenient util to let a user grab all existing ops with a CompositeImplicitAutograd decomp.
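As an illustrative sketch of what that could look like, the snippet below re-runs CompositeImplicitAutograd decomps on a pre_dispatch exported program via run_decompositions. The hand-picked op list is a stand-in for the missing util, and the export entry point is assumed to match the one discussed in the issue.

```python
import torch
from torch._decomp import get_decompositions
from torch._export import export  # exact entry point varies by PyTorch version

m = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(1, 4),)

# pre_dispatch tracing keeps CompositeImplicitAutograd ops (e.g. aten.linear) in the graph.
ep = export(m, example_inputs, pre_dispatch=True)

# Stand-in for the missing "grab all ops with a CompositeImplicitAutograd decomp" util:
# a hand-picked list of such ops.
composite_implicit_ops = [
    torch.ops.aten.linear.default,
    torch.ops.aten.matmul.default,
]

# Re-apply the decomps that pre_dispatch tracing skipped.
ep = ep.run_decompositions(get_decompositions(composite_implicit_ops))
print(ep.graph_module.graph)
```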
Support for export(..., pre_dispatch=True) was added to make sure autograd (and autocast) functionality (e.g. torch.no_grad) continues to work on exported graphs. This is needed for training on an exported model. The way this works is it inserts torch._C._set_grad_enabled ops into the graph.

We have a use case where we want to use export(..., pre_dispatch=True) for models in eval mode. However, I just tried this out and it looks like I'm still seeing these torch._C._set_grad_enabled ops in the graph. This may lead to unexpected behavior during eval, where some gradients may be computed even though we don't actually need them.

Minimal repro:
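(A minimal sketch along these lines, assuming a toy module with a no_grad() in its forward and the export(..., pre_dispatch=True) entry point; not the exact snippet from the report:)

```python
import torch
from torch._export import export  # exact entry point varies by PyTorch version

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The no_grad() here is what pre_dispatch tracing records as _set_grad_enabled.
        with torch.no_grad():
            return self.linear(x)

m = M().eval()
example_inputs = (torch.randn(1, 4),)

ep = export(m, example_inputs, pre_dispatch=True)

# The graph still contains torch._C._set_grad_enabled calls even though the model is in eval mode.
print(ep.graph_module.graph)
```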
Output: