pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Semantic discrepancy on requires_grad after compiling Tensor.detach #1052

Closed: sangongs closed this issue 1 year ago

sangongs commented 1 year ago

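To reproduce:

import torch
from functorch.compile import aot_function

def fn(x):
    return x.detach()

# Identity "compiler": run the traced FX graph as-is.
aot_fn = aot_function(fn, fw_compiler=lambda fx_module, _: fx_module)

x = torch.randn(1, requires_grad=True)
ref = fn(x)      # eager: ref.requires_grad is False
res = aot_fn(x)

# Fails on the nightly below: res.requires_grad comes back True.
assert ref.requires_grad == res.requires_grad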

PyTorch version: 1.13.0.dev20220929+cu116

Not sure if this is related to #376.

samdow commented 1 year ago

cc @bdhirsh: I might be totally off, but is this related to any of the work you were doing to make requires_grad track correctly on proxies?

bdhirsh commented 1 year ago

Hmm, I don't think so. It looks like it's because we're compiling the whole thing (including the detach() call) into an autograd.Function, and autograd.Function marks all of its forward outputs as requiring gradients whenever any input requires grad, regardless of what the forward body did to them.
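
A minimal sketch of that behavior, using a hand-written Function instead of the one AOT Autograd generates (the class name DetachInside is made up for illustration). The forward body only calls detach(), mirroring the traced graph, yet the output of apply() still requires grad because the input does.

```python
import torch

class DetachInside(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Mirrors the compiled graph: the forward body just detaches.
        return x.detach()

    @staticmethod
    def backward(ctx, grad_out):
        # Pass-through gradient; not exercised in this example.
        return grad_out

x = torch.randn(1, requires_grad=True)
print(x.detach().requires_grad)             # False: plain eager detach
print(DetachInside.apply(x).requires_grad)  # True: the Function attaches history
```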

@albanD brought up a good point: in AOT Autograd, we already run the forward once to get the expected output(s) to pass to the joint graph for tracing, and at that point we should know the expected requires_grad-ness of every forward output. We can use autograd.Function's mark_non_differentiable API to (statically) mark those outputs as not requiring gradients, which would fix this problem.
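
A minimal sketch of the mark_non_differentiable approach on a hand-written Function (DetachMarked is an illustrative name, not AOT Autograd's generated class). Here the output is a plain, non-view tensor, and marking it keeps requires_grad=False after apply().

```python
import torch

class DetachMarked(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x.detach()
        # Statically declare this output non-differentiable, as a
        # reference eager run would report for it.
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Not reached here: the only output is non-differentiable.
        return grad_out

x = torch.randn(1, requires_grad=True)
y = DetachMarked.apply(x)
print(y.requires_grad)  # False
```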

That would technically make the autograd.Function we create do the wrong thing if you reused it with inputs whose .requires_grad values differ. But we're hiding behind Dynamo, and Dynamo already specializes on requires_grad-ness today, so we can expect to trace out a new autograd.Function object whenever that happens.

bdhirsh commented 1 year ago

Here's a potential fix based on my discussion with Alban: https://github.com/pytorch/pytorch/pull/86838

sangongs commented 1 year ago

Thanks to @bdhirsh for the quick fix. However, the following program still fails after cherry-picking the PR:

import torch
from functorch.compile import aot_function
from torchinductor.compile_fx import compile_fx_inner

def fn(x):
    # A view followed by detach(): the eager output does not require grad.
    y = x.view(-1).detach()
    return y

aot_fn = aot_function(fn, fw_compiler=compile_fx_inner)

x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)

assert ref.requires_grad == res.requires_grad

bdhirsh commented 1 year ago

It looks like that repro runs OK with the aot-eager backend (the nop compiler below), but not with Inductor:

import torch
from functorch.compile import aot_function, nop

def fn(x):
    y = x.view(-1).detach()
    return y

# nop runs the traced graph unchanged (the "aot-eager" setup).
aot_fn = aot_function(fn, fw_compiler=nop)

x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)

assert ref.requires_grad == res.requires_grad

When I print the output of inductor's codegen, I get:

def call(args):
    primals_1, = args
    args.clear()
    primals_1_size = primals_1.size()
    s0 = primals_1_size[1]
    return (as_strided(primals_1, (s0, ), (1, )), )

It looks like Inductor is treating the .detach() call(s) in the original graph as no-ops.

To be fair, one could argue that Inductor shouldn't have to worry about requires_grad when it's compiling. I'm not exactly sure what the fix is, though: even though we call mark_non_differentiable() on the outputs in AOT Autograd, they still end up with requires_grad=True.

sangongs commented 1 year ago

Where it looks like inductor is treating the .detach() call(s) in the original graph as no-ops.

Yes, Inductor treats .detach() calls as no-ops: https://github.com/pytorch/torchdynamo/blob/986da5a19055e99901220ffdc18b80558b54aa7b/torchinductor/lowering.py#L467-L469

To be fair, it seems fair to argue that inductor shouldn't have to worry about requires_grad when it's compiling?

Agreed. It would be good if AOT Autograd could handle this automatically, although in theory Inductor could generate code to handle detach.

bdhirsh commented 1 year ago

@albanD Does this sound like correct autograd.Function behavior? Based on the docs, I would have expected any tensor marked with mark_non_differentiable in the forward to have requires_grad=False.

If that sounds like incorrect behavior, I can dig into autograd.Function a bit more. Alternatively, if this is a limitation of autograd.Function, then our options are probably either to move on to something else (like what Richard has brought up before) or to properly handle requires_grad-ness in Inductor. Here's my example:

class CompiledFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a):
        out = torch.as_strided(a, a.shape, a.stride())
        # Explicitly mark the output as being non-differentiable
        # EVEN IF it appears to require gradients.
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, *flat_args):
        # ignore, not called
        return tuple(flat_args)

a = torch.ones(2, 2, requires_grad=True)
b = CompiledFunction.apply(a)
# Prints True, even though we marked it as non-differentiable?
print(b.requires_grad)

sangongs commented 1 year ago

I guess I found the reason why mark_non_differentiable does not unset requires_grad in this case. It might be because of this piece of code: https://github.com/pytorch/pytorch/blob/ae45dab57e22e3d04516e7dd81ef8dbefd51bfe3/torch/csrc/autograd/custom_function.cpp#L290-L299

Basically, if the output is a view, mark_non_differentiable() has no effect on it.

Maybe a stupid question, but can we just apply .detach() to non-differentiable outputs instead of using mark_non_differentiable()?
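
One way to read that suggestion, sketched on the example above: detach the view inside forward and then mark the detached tensor. The helper `mark_non_differentiable_detached` is hypothetical, not a PyTorch API. Detaching alone inside forward would not be enough, since an unmarked output gets history attached again by the Function; the detach is what stops the output from being a differentiable view, so the mark can take effect.

```python
import torch

def mark_non_differentiable_detached(ctx, *outs):
    # Hypothetical helper: detach first so the outputs are no longer
    # differentiable views, then mark them. Forward must return these.
    detached = tuple(o.detach() for o in outs)
    ctx.mark_non_differentiable(*detached)
    return detached

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        out = torch.as_strided(a, a.shape, a.stride())  # a view of the input
        (out,) = mark_non_differentiable_detached(ctx, out)
        return out

    @staticmethod
    def backward(ctx, *flat_args):
        # Not called: the only output is non-differentiable.
        return tuple(flat_args)

a = torch.ones(2, 2, requires_grad=True)
b = CompiledFunction.apply(a)
print(b.requires_grad)  # False
```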

bdhirsh commented 1 year ago

@sangongs nice catch. I'll defer to Alban, but I think that seems reasonable. It feels bad, because it looks like we're working around autograd.Function's existing behavior, but the only reason we need to is that there was a .detach() in the original graph that the compiler removed, so we're just adding it back.

sangongs commented 1 year ago

I came up with a work-around in Inductor to deal with this special tensor.view().detach() case: https://github.com/pytorch/torchdynamo/pull/1661

albanD commented 1 year ago

This is a good catch. In general, setting requires_grad on a differentiable view indeed has no effect, because its requires_grad field is set to reflect its base's requires_grad-ness.
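
A small eager-mode illustration of that view behavior (no custom Function involved). A view created under torch.no_grad() still reports requires_grad=True because it follows its base, while a non-view result or a detached view does not. A custom Function's forward also runs with grad mode disabled, which is why the as_strided output in the earlier example keeps requires_grad=True.

```python
import torch

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    y = x.view(-1)  # differentiable view: requires_grad follows the base x
    w = x * 2       # ordinary (non-view) result computed without grad
    z = y.detach()  # detached: no longer tied to x's requires_grad

print(y.requires_grad)  # True
print(w.requires_grad)  # False
print(z.requires_grad)  # False
```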

In this case, if the user explicitly states that the output is not differentiable, then we should properly detach it, since it is a non-differentiable view. cc @soulitzer I think this is something we want to solve on the custom Function side.

bdhirsh commented 1 year ago

@albanD, to confirm: you think this is something that should be handled transparently by autograd.Function?

That is, if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.Function should implicitly .detach() it?

albanD commented 1 year ago

you think that this is something that should be handled transparently by autograd.function?

Yes

sangongs commented 1 year ago

It looks like the issue is still not fixed for backends like Inductor that do not handle detach(). @bdhirsh, do you plan to implement this:

aka if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.function should implicitly .detach() it?