Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Define `torchsymbol` for `torch.ops.higher_order.autograd_function_apply` #1106

Closed: crcrpar closed this 1 month ago

crcrpar commented 2 months ago

What does this PR do?

This could be useful when we see torch.ops.higher_order.autograd_function_apply, which torch.compile introduces into a GraphModule when the graph includes a custom torch.autograd.Function.
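
For context, a minimal sketch (illustrative only, not taken from the PR) of the situation described above: a custom torch.autograd.Function passed through torch.compile, whose captured GraphModule may then contain a call to torch.ops.higher_order.autograd_function_apply. The exact graph contents depend on the PyTorch version.

```python
import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x.cos()

@torch.compile
def f(x):
    # When Dynamo traces through MySin.apply, the captured GraphModule may
    # contain torch.ops.higher_order.autograd_function_apply.
    return MySin.apply(x)

a = torch.randn(3, 4, requires_grad=True)
f(a).sum().backward()
```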

t-vi commented 2 months ago

Given that this is dynamo-specific, this should go into the dynamo section rather than jit_ext, please.

IvanYashchuk commented 2 months ago

> Given that this is dynamo-specific, this should go into the dynamo section rather than jit_ext, please.

I think you're confusing Dynamo and Functorch. It's a Functorch-specific thing. Functorch concepts integrate nicely with Thunder because they provide functional-style (side-effect-free) equivalents of object-oriented PyTorch concepts. It's not difficult to convert torch.autograd.Function subclasses into a call to torch._functorch.autograd_function.autograd_function_apply that explicitly carries the forward and backward definitions, which can in turn be used by Thunder's reverse differentiation pass.

No Dynamo is involved in the following code:

```python
import torch
from torch._functorch.autograd_function import autograd_function_apply

def forward(ctx, input):
    # Explicitly return the tensors saved for backward alongside the output.
    saved_for_backward = (input,)
    return input.sin(), saved_for_backward

def backward(ctx, grad_output, *saved_tensors):
    (input,) = saved_tensors
    return grad_output * input.cos()

def my_sin(input):
    return autograd_function_apply(forward, backward, input, args_tensor_mask=[True])

a = torch.randn(3, 4, requires_grad=True)
g = torch.ones_like(a)
my_sin(a).backward(g)
torch.testing.assert_close(a.grad, g * a.cos())
```

@t-vi, do you have other ideas on how to properly support torch.autograd.Function.backward definitions in Thunder (https://github.com/Lightning-AI/lightning-thunder/issues/1017)?

t-vi commented 2 months ago

So I'm good with merging this (incremental progress and all), but would we want it to share the infra with #1125? We have a variation of this in #1123, too.

I thought about pushing an update to the test with a manipulated backward (happy to still do it if it saves time) and then merging, but wanted to check whether you want it merged as is with a follow-up, or whether you want to update this PR further.
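
For reference, a hedged sketch of what a test with a manipulated backward might look like (function names here are illustrative, not taken from the PR): the backward deliberately disagrees with the analytic derivative of sin, so the assertion only passes if the user-supplied bwd is actually used. Presumably the actual test would also run the function through thunder.jit and compare the results.

```python
import torch
from torch._functorch.autograd_function import autograd_function_apply

def fwd(ctx, x):
    # Return the output together with the tensors saved for backward.
    return x.sin(), (x,)

def bwd(ctx, grad_output, *saved):
    (x,) = saved
    # Deliberately "manipulated": 2 * cos(x) instead of the true derivative cos(x).
    return grad_output * 2 * x.cos()

def manipulated_sin(x):
    return autograd_function_apply(fwd, bwd, x, args_tensor_mask=[True])

a = torch.randn(3, 4, requires_grad=True)
g = torch.ones_like(a)
manipulated_sin(a).backward(g)
# The gradient follows the manipulated bwd, not sin's analytic derivative.
torch.testing.assert_close(a.grad, 2 * g * a.cos())
```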

crcrpar commented 2 months ago

> would we want it to share the infra with #1125?

Could you elaborate on this? torch.autograd.Function can depend heavily on ctx, the first argument of forward and backward, so we'd have to look at what happens under the hood to figure out which intermediate tensors are saved for backward. torch.ops.higher_order.autograd_function_apply, however, is largely free of that, and we can just use fwd and bwd naively. On top of that, IMHO having the logic in thunder/torch/__init__.py is more readable than having it in thunder/core/jit_ext.py, but it seems quite hard for the torch.autograd.Function logic to live in thunder.torch.

t-vi commented 2 months ago

> would we want it to share the infra with https://github.com/Lightning-AI/lightning-thunder/pull/1125?

> Could you elaborate on this?

> torch.ops.higher_order.autograd_function_apply, however, is largely free of that, and we can just use fwd and bwd naively. On top of that, IMHO having the logic in thunder/torch/__init__.py is more readable than having it in thunder/core/jit_ext.py, but it seems quite hard for the torch.autograd.Function logic to live in thunder.torch.

Yeah, so the advantage of having it here as it is seems to be that it lives easily in thunder.torch; the advantage of integrating it more into the jit would seem to be that one could trace into fwd and bwd, similar to what we do for torch.autograd.Function.

You are totally right that the function logic in jit_ext.py is not ideal. Maybe we could have a more structured JIT hooking facility that takes care of the wrapping bits, wraps the _interpret_call pieces (including error checking of the return values), and moves the "purely autograd.Function" things out of jit_ext.py.

crcrpar commented 2 months ago

I'm not following why you maintain that jit_ext would trace into fwd and bwd, since I think the current implementation does that as well. My original implementation was basically in jit_ext, and I now think it was unnecessarily complicated because of the logic needed to create bsyms and traces.

t-vi commented 2 months ago

I am not set on it. Either way works for me.

crcrpar commented 1 month ago

@t-vi @IvanYashchuk friendly ping