Closed: sangongs closed this issue 1 year ago.
cc @bdhirsh might be totally off, but is this related to any of the work you were doing to make requires_grad track correctly on proxies?
Hmm, I don't think so. It looks like it's because we're compiling the whole thing (including the detach() call) into an autograd.Function, and autograd.Function will unconditionally mark all of its forward outputs as requiring gradients.
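A minimal sketch of that behavior, assuming a recent PyTorch build. The class name Wrapped is made up for illustration; its forward detaches its result, like the compiled graph does. Because the input requires grad, autograd.Function still attaches a grad_fn to the output.

```python
import torch


class Wrapped(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Detach inside forward, like the compiled graph in this issue.
        return (x * 2).detach()

    @staticmethod
    def backward(ctx, grad):
        return grad * 2


x = torch.randn(3, requires_grad=True)
y = Wrapped.apply(x)
# The Function still wires y into the graph because x requires grad.
print(y.requires_grad)  # True
```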
@albanD brought up a good point: in AOT autograd, we already run the forward once to get the expected output(s) to pass to the joint graph for tracing, and at that point we should know the expected requires-gradness of every forward output. We can use autograd.Function's mark_non_differentiable API to (statically) mark those outputs as not requiring gradients, which would fix this problem.
That would technically make the autograd.Function that we create do the wrong thing if you re-used it with inputs that have different .requires_grad values. But we're hiding behind dynamo, and dynamo already specializes on requires_grad-ness today, so we can expect to trace out a new autograd.Function object whenever that happens.
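A rough sketch of the idea, not the actual AOT autograd code: make_compiled_function is a hypothetical helper that runs the eager function once, records which outputs don't require grad, and marks those outputs with ctx.mark_non_differentiable inside a freshly built autograd.Function. It assumes fn returns a tuple of tensors, and the hypothetical f_without_detach stands in for a compiled forward that dropped the .detach() call.

```python
import torch


def make_compiled_function(fn, compiled_forward, *example_inputs):
    """Hypothetical helper: decide non-differentiable outputs once, from eager mode."""
    expected = fn(*example_inputs)  # eager run; assumes a tuple of tensors
    non_diff = [i for i, out in enumerate(expected) if not out.requires_grad]

    class CompiledFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            outs = compiled_forward(*args)
            # Statically mark the outputs that eager mode said don't require grad.
            ctx.mark_non_differentiable(*[outs[i] for i in non_diff])
            return outs

        @staticmethod
        def backward(ctx, *grads):
            raise NotImplementedError("backward is omitted in this sketch")

    return CompiledFunction.apply


def f(x):
    return x * 2, (x * 3).detach()


def f_without_detach(x):
    # Stand-in for a compiled forward that dropped the .detach() call.
    return x * 2, x * 3


x = torch.randn(3, requires_grad=True)
compiled = make_compiled_function(f, f_without_detach, x)
a, b = compiled(x)
print(a.requires_grad, b.requires_grad)  # True False
```

Because non_diff is captured once, the resulting Function is only correct for inputs with the same requires_grad flags, which is why relying on dynamo's requires_grad specialization matters. Note that both outputs here are fresh tensors rather than views of the input.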
Here's a potential fix based on my discussion with Alban: https://github.com/pytorch/pytorch/pull/86838
Thanks to @bdhirsh for the quick fix. However, the following program still fails after cherry-picking the PR:
import torch
from functorch.compile import aot_function, make_boxed_func
from torchinductor.compile_fx import compile_fx_inner

def fn(x):
    y = x.view(-1).detach()
    return y

aot_fn = aot_function(fn, fw_compiler=compile_fx_inner)
x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)
assert ref.requires_grad == res.requires_grad
It looks like that repro runs ok with the aot-eager backend, but not with inductor:
import torch
from functorch.compile import aot_function, make_boxed_func, nop

def fn(x):
    y = x.view(-1).detach()
    return y

aot_fn = aot_function(fn, fw_compiler=nop)
x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)
assert ref.requires_grad == res.requires_grad
When I print the output of inductor's codegen, I get:
def call(args):
    primals_1, = args
    args.clear()
    primals_1_size = primals_1.size()
    s0 = primals_1_size[1]
    return (as_strided(primals_1, (s0, ), (1, )), )
It looks like inductor is treating the .detach() call(s) in the original graph as no-ops.
To be fair, one could argue that inductor shouldn't have to worry about requires_grad when it's compiling? I'm not exactly sure what the fix is, though. It looks like even though we're calling mark_non_differentiable() on the outputs in AOT autograd, they still end up with requires_grad=True.
> Where it looks like inductor is treating the .detach() call(s) in the original graph as no-ops.
Yes, Inductor treats .detach() calls as no-ops:
https://github.com/pytorch/torchdynamo/blob/986da5a19055e99901220ffdc18b80558b54aa7b/torchinductor/lowering.py#L467-L469
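To see why dropping .detach() changes the result, here is a sketch that simulates such a lowering with torch.fx rather than with Inductor itself: it traces fn, rewires every detach node to its input, and compares requires_grad against eager mode.

```python
import torch
import torch.fx


def fn(x):
    return x.view(-1).detach()


gm = torch.fx.symbolic_trace(fn)
# Simulate a lowering that treats detach() as a no-op: point every user of a
# detach node at the detach's input, then remove the node.
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "detach":
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()

x = torch.randn(1, 2, requires_grad=True)
print(fn(x).requires_grad)  # False (eager)
print(gm(x).requires_grad)  # True (detach dropped)
```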
> To be fair, it seems fair to argue that inductor shouldn't have to worry about requires_grad when it's compiling?
Agreed. It would be good if AOT autograd could handle this automatically, although in theory Inductor could generate code to handle detach.
@albanD Does this sound like correct autograd.Function behavior? Based on the docs, I would have expected any tensors marked with mark_non_differentiable in the forward to have requires_grad=False.
If that sounds like incorrect behavior, I can dig into autograd.Function a bit more. Alternatively, if this is a limitation of autograd.Function, then our options are probably either to move on to something else (like what Richard has brought up before) or to properly handle requires_grad-ness in inductor. Here's my example:
import torch

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        out = torch.as_strided(a, a.shape, a.stride())
        # Explicitly mark the output as being non-differentiable
        # EVEN IF it appears to require gradients.
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, *flat_args):
        # ignore, not called
        return tuple(flat_args)

a = torch.ones(2, 2, requires_grad=True)
b = CompiledFunction.apply(a)
# Prints True. Even though we marked it as non-differentiable?
print(b.requires_grad)
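For comparison, a sketch (the class name Marked is illustrative) where the marked output is a fresh tensor rather than a view of the input. In that case mark_non_differentiable does produce requires_grad=False, as the docs describe, which suggests the view is what makes the example above behave differently.

```python
import torch


class Marked(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        out = a * 2  # a fresh tensor, not a view of a
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, grad):
        # Not called: the only output is non-differentiable.
        return grad * 2


a = torch.ones(2, 2, requires_grad=True)
c = Marked.apply(a)
print(c.requires_grad)  # False
```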
I think I found the reason why mark_non_differentiable does not unset requires_grad in this case. It might be because of this piece of code:
https://github.com/pytorch/pytorch/blob/ae45dab57e22e3d04516e7dd81ef8dbefd51bfe3/torch/csrc/autograd/custom_function.cpp#L290-L299
Basically, if the output is a view, then mark_non_differentiable() has no effect on it.
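The underlying rule is visible from Python too: forward runs with grad mode disabled, and a view taken under torch.no_grad() of a tensor that requires grad still reports requires_grad=True, while a freshly computed tensor does not. A minimal sketch:

```python
import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    view = x.view(-1)  # differentiable view: requires_grad follows the base
    fresh = x * 2      # new tensor computed without grad tracking
print(view.requires_grad, fresh.requires_grad)  # True False
```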
Maybe a stupid question, but can we just apply .detach() to non-differentiable outputs instead of calling mark_non_differentiable()?
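A minimal sketch of that suggestion, assuming the caller already knows which outputs should be non-differentiable (here, the only output). ViewFn mirrors the as_strided example above, and apply_and_detach is a hypothetical wrapper that detaches after apply() instead of calling mark_non_differentiable. Detaching after apply() is the important part: a tensor detached inside forward would still get a grad_fn attached by autograd.Function unless it were also marked.

```python
import torch


class ViewFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        # Return a view of the input, as in the as_strided example.
        return torch.as_strided(a, a.shape, a.stride())

    @staticmethod
    def backward(ctx, grad):
        return grad


def apply_and_detach(a):
    # Hypothetical wrapper: the caller knows this output should not require
    # grad, so detach it after apply() instead of marking it inside forward.
    return ViewFn.apply(a).detach()


a = torch.ones(2, 2, requires_grad=True)
print(ViewFn.apply(a).requires_grad)      # True
print(apply_and_detach(a).requires_grad)  # False
```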
@sangongs nice catch. I'll defer to Alban, but... I think that seems reasonable (it feels bad, because it looks like we're trying to ignore autograd.Function's existing behavior, but the only reason for that is that there was a .detach() in the original graph that the compiler removed, so... we're adding it back).
I came up with a workaround in Inductor to deal with this special tensor.view().detach() case: https://github.com/pytorch/torchdynamo/pull/1661
This is a good catch.
In general, indeed, setting requires_grad on a differentiable view has no effect, as the view's requires_grad field is set to reflect its base's requires_grad-ness.
In this case, if the user explicitly states that this is not differentiable, then we should properly detach it, as it is a non-differentiable view. cc @soulitzer I think this is something we want to solve on the custom Function side.
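Detaching does not copy data: .detach() returns a non-differentiable view that shares storage with the original tensor, so implicitly detaching a marked view should be cheap. A quick sketch:

```python
import torch

x = torch.randn(1, 2, requires_grad=True)
v = x.view(-1)   # differentiable view of x
d = v.detach()   # non-differentiable view of the same storage
print(v.requires_grad, d.requires_grad)  # True False
print(d.data_ptr() == x.data_ptr())      # True: no copy was made
```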
@albanD to confirm: you think that this is something that should be handled transparently by autograd.Function?
aka if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.Function should implicitly .detach() it?
> you think that this is something that should be handled transparently by autograd.Function?
Yes
Looks like the issue is still not fixed for backends like inductor that do not handle detach(). @bdhirsh Do you have plans to implement this:
> aka if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.Function should implicitly .detach() it?
Reproduce:
PyTorch version: 1.13.0.dev20220929+cu116
Not sure if this is related to #376.