pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Better export story for autograd.Function? #106388

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

🐛 Describe the bug

Exporting a function that calls an autograd.Function leaves an opaque trampoline call in the exported graph:

import torch
import torch._dynamo

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()
    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * x.cos()

def f(x):
    return Foo.apply(x)

x = torch.randn([], requires_grad=True)
gm, *_ = torch._dynamo.export(f, x)
print(gm.code)

This prints:

def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    function_ctx = torch.autograd.function.FunctionCtx()
    trampoline_autograd_apply = torch__dynamo_variables_misc_trampoline_autograd_apply(arg0);  arg0 = None
    return pytree.tree_unflatten([trampoline_autograd_apply], self._out_spec)

We probably want to turn autograd.Function into a legit HigherOrderOp at some point.

Versions

main

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @wconstab

ezyang commented 1 year ago

Note that if you export with aten_graph=True, I would assume the trampoline evaporates, so the higher-order op is mostly needed for pre-dispatch export. cc @andrewor14
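The expectation that the autograd.Function disappears at the aten level can be checked with a small sketch. It uses make_fx directly as a stand-in for aten-level tracing; that substitution is an assumption, not what the thread ran, and the exact printed graph may differ between PyTorch versions.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * x.cos()

def f(x):
    return Foo.apply(x)

x = torch.randn([], requires_grad=True)

# make_fx records ops at the dispatcher level, below autograd, so the
# body of Foo.forward is traced into aten ops rather than kept as one node.
gm = make_fx(f)(x)
print(gm.code)
```

In this sketch the traced graph should contain an aten sin call and no reference to Foo, which is also why the custom backward is no longer visible to a consumer of the graph.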

BowenBao commented 1 year ago

@ezyang Is there a route where we can keep the autograd function op node without tracing into it, while doing aten_graph?

ezyang commented 1 year ago

You're kind of out of luck, because aten_graph requires you to produce an aten op, but there is no aten op for a custom autograd function. Perhaps this can be worked around by relaxing requirements (you don't really want to produce an aten op only graph, you are ok with opaque autograd function blobs), but I don't have full context on what's going on for your case.

thiagocrepaldi commented 8 months ago

@ezyang @suo is preserving

> You're kind of out of luck, because aten_graph requires you to produce an aten op, but there is no aten op for a custom autograd function. Perhaps this can be worked around by relaxing requirements (you don't really want to produce an aten op only graph, you are ok with opaque autograd function blobs), but I don't have full context on what's going on for your case.

We would certainly be interested in a mechanism that keeps custom autograd functions as opaque operators, which we could then replace with proper ops in a post-processing step.

suo commented 8 months ago

@zou3519 is working on a new custom operator API that should do that: https://github.com/pytorch/pytorch/pull/120345