Closed: crcrpar closed this 1 month ago
Given that this is dynamo-specific, this should go into the dynamo section rather than jit_ext, please.
I think you're confusing Dynamo and Functorch. This is a Functorch-specific thing. Functorch concepts integrate nicely with Thunder because Functorch is oriented at providing functional-style (side-effect-free) equivalents of object-oriented PyTorch concepts. It's not difficult to convert torch.autograd.Function subclasses into a call to torch._functorch.autograd_function.autograd_function_apply that carries explicit forward and backward definitions, which in turn can be used by Thunder's reverse differentiation pass.
No Dynamo is involved in the following code:
import torch
from torch._functorch.autograd_function import autograd_function_apply


# fwd returns the output together with the values to save for backward
def forward(ctx, input):
    saved_for_backward = (input,)
    return input.sin(), saved_for_backward


# bwd receives the output gradients followed by the saved values
def backward(ctx, grad_output, *saved_tensors):
    (input,) = saved_tensors
    return grad_output * input.cos()


def my_sin(input):
    return autograd_function_apply(forward, backward, input, args_tensor_mask=[True])


a = torch.randn(3, 4, requires_grad=True)
g = torch.ones_like(a)
my_sin(a).backward(g)
torch.testing.assert_close(a.grad, g * a.cos())
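For comparison, a sketch of the same operation written as a torch.autograd.Function subclass (the class name MySin is made up for illustration); the difference is that the saved state goes through ctx instead of being returned and passed around explicitly:

import torch


class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # state for backward is stashed on ctx instead of being returned
        ctx.save_for_backward(input)
        return input.sin()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return grad_output * input.cos()


a = torch.randn(3, 4, requires_grad=True)
g = torch.ones_like(a)
MySin.apply(a).backward(g)
torch.testing.assert_close(a.grad, g * a.cos())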
@t-vi, do you have other ideas on how to properly support torch.autograd.Function.backward definitions in Thunder (https://github.com/Lightning-AI/lightning-thunder/issues/1017)?
So I'm good with merging this, incremental progress and all, but would we want it to share the infra with #1125? We have a variation of this in #1123, too.
I thought about pushing an update to the test with a manipulated backward (happy to still do it if it saves time) and merging, but I wanted to check whether you want it merged as is and followed up on, or whether you want to update this PR further.
would we want it to share the infra with #1125?
could you elaborate on this?
torch.autograd.Function can depend heavily on ctx, the first argument of forward and backward, so we'd have to look into what's going on under the hood to tell which intermediate tensors backward needs (see the sketch below). torch.ops.higher_order.autograd_function_apply, however, is largely free from ctx, and we could just naively use fwd and bwd.
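For illustration, a made-up Function that shows this kind of ctx dependence: what backward consumes is only discoverable by inspecting forward, because forward stashes an intermediate tensor and a plain attribute on ctx, whereas with autograd_function_apply the saved values are returned from fwd and passed to bwd explicitly.

import torch


class ScaledExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        out = (scale * input).exp()
        # what backward needs is only visible by reading forward:
        ctx.save_for_backward(out)
        ctx.scale = scale  # arbitrary attributes can be hung off ctx, too
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d/dx exp(scale * x) = scale * exp(scale * x) = scale * out
        return grad_output * out * ctx.scale, None


x = torch.randn(3, requires_grad=True)
ScaledExp.apply(x, 2.0).sum().backward()
torch.testing.assert_close(x.grad, 2.0 * (2.0 * x).exp())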
On top of that, IMHO having the logic in thunder/torch/__init__.py would be more readable than having it in thunder/core/jit_ext.py, but it feels quite hard for the torch.autograd.Function logic to live in thunder.torch.
Yeah, so the advantage of having it here as it is seems to be that it lives in thunder.torch easily; the advantage of integrating it more into the jit would seem to be that one could trace into fwd and bwd, similar to what we do for torch.autograd.Function.
You are totally right about the Function logic in jit_ext.py not being ideal. Maybe we could have a more structured JIT hooking facility that takes care of the wrapping bits, wraps the _interpret_call things (including error checking on return), and moves the purely-autograd.Function things out of jit_ext.py.
I'm not following why you maintain that jit_ext would trace into fwd and bwd, as I think the current implementation also does that. My original implementation was basically in jit_ext, and I now think it was unnecessarily complicated due to the logic to create bsyms and traces.
I am not set on it. Either way works for me.
@t-vi @IvanYashchuk friendly ping
What does this PR do?
This could be useful when we see torch.ops.higher_order.autograd_function_apply, which torch.compile introduces into a GraphModule if the graph includes a custom torch.autograd.Function.
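As an illustration of that behavior, a minimal sketch assuming a recent PyTorch (the Square class and the print_graph backend are hypothetical names): compiling a function that calls a custom torch.autograd.Function and printing the Dynamo-captured graph should show a call to torch.ops.higher_order.autograd_function_apply.

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


def print_graph(gm, example_inputs):
    # the captured GraphModule is expected to contain a call to
    # torch.ops.higher_order.autograd_function_apply
    gm.print_readable()
    return gm.forward


compiled = torch.compile(lambda t: Square.apply(t), backend=print_graph)
compiled(torch.randn(3, requires_grad=True))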