pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

autograd.Function is silently incorrect on functorch transforms #207

Closed by zou3519 1 year ago

zou3519 commented 2 years ago

We should fix it. Not quite sure how though.

https://gist.github.com/zou3519/73a6189e21561f6ef5b42874e8a4826f

zou3519 commented 2 years ago

What’s going on in grad(grad(foo))(x):

foo(GradTensor(GradTensor(x)))

MyReLU.apply(GradTensor(GradTensor(x)))
> y = GradTensor(GradTensor(x)).clamp(min=0)
> > GradTensor(x).clamp(min=0)
> > > x.clamp(min=0)
y.grad_fn = MyReLU

The inner GradTensor doesn’t see the autograd.Function because it is not represented in the dispatcher :(

zou3519 commented 2 years ago

Here's a (public) writeup of what the problem is: https://docs.google.com/document/d/1sPRJyP_vkZEY3RbNBcy2hLmXxnvZlV7cB6_-uxyVMAE/edit?usp=sharing

AlphaBetaGamma96 commented 2 years ago

Hi @zou3519, I was just wondering if there's been any update on adding custom autograd Function support to FuncTorch? Thank you! :)

zou3519 commented 2 years ago


We're still thinking about what it would look like. It probably won't happen in the next couple of weeks though :( but we'll keep you updated!

rejuvyesh commented 2 years ago

Currently functorch gives the following error when it encounters autograd.Function:

functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this

Are there any examples of how one can rewrite a function to not use autograd.Function if we need to define a custom backward/vjp? I believe this error comes from a plain call to functorch.vjp, without any nesting.

zou3519 commented 2 years ago

@rejuvyesh the title was a bit misleading, it turns out autograd.Function can have silently incorrect behavior on all of our transforms.

> Are there any examples for how one can rewrite a function to not use autograd.Function if we need to define a custom backward/vjp?

Not yet, but this is pretty top of mind for us.

hturki commented 2 years ago

Just to confirm, does this mean that we can't use functorch to compute jacobians for anything that relies on a custom backward function? Pretty unfortunate considering the state of https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html and how it points to functorch. Are there any updates on the horizon?

zou3519 commented 2 years ago

@hturki yes, that's unfortunately correct. This is still top of mind for us, but it is a bit difficult to implement. We've planned to implement it sometime in the next couple of months.

zou3519 commented 2 years ago

@hturki is there an example of a custom backward function you want to use with functorch?

hturki commented 2 years ago

Thanks for the update. I'd actually like to hook up a third-party library (tiny-cuda-nn: https://github.com/NVlabs/tiny-cuda-nn/blob/master/bindings/torch/tinycudann/modules.py#L41) that has PyTorch bindings. It just currently happens to integrate with torch via the current autograd tooling; if there ends up being a workaround that integrates directly with functorch, that would work as well.

zou3519 commented 2 years ago

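Thanks for the example. There are two types of autograd.Function:

1. those whose forward and backward passes are written entirely in PyTorch operations, and
2. those whose forward and backward passes call into code outside of PyTorch, such as custom CUDA kernels.

That example looks like the latter. We don't support it yet, but in the future, if someone wants to vmap over an autograd.Function whose forward and backward passes are both custom CUDA kernels, they will need to write a custom batching rule so that vmap knows how to work with it.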

zou3519 commented 1 year ago

We've finally gotten to revisiting this. I'm curious to hear about other use cases for this feature so we know what to prioritize. Calling on some folks (but please feel free to comment if you have thoughts!): @rejuvyesh, could you tell us a bit about your use case please?

rejuvyesh commented 1 year ago

For me, integration with vmap is less important. However, functorch's explicit interface is much nicer to integrate with other autodiff systems, such as those in Julia. I use it in PyCallChainRules.jl so that users can integrate their differentiable PyTorch functions in Julia. The biggest use case is actually wrapping large CUDA kernels, which usually have a PyTorch interface and where dispatch performance is usually not a big concern. However, as @hturki found, it won't work with tiny-cuda-nn, for example.

zou3519 commented 1 year ago

> For me, integration with vmap is less important. However, functorch's explicit interface is much nicer to generally integrate with other autodiff systems like those in say Julia. I use it in PyCallChainRules.jl so that users can integrate their differentiable pytorch functions in Julia

Thanks for the feedback!

zou3519 commented 1 year ago

@hturki (and @rejuvyesh if you have thoughts) - for hooking up a third-party library like tiny-cuda-nn: https://github.com/NVlabs/tiny-cuda-nn/blob/master/bindings/torch/tinycudann/modules.py#L41, there is a caveat for jacobian computation (and integration with vmap).

My understanding is that third-party libraries may define CUDA kernels for the forward and backward passes that are stitched together via autograd.Function.

In order to compute its jacobian, we would need to be able to vmap over the backward pass. So functorch would provide a mechanism to add a vmap (batching) rule for the backward pass (like how autograd.Function allows you to define a custom gradient rule), and the user would have to add the batching rule. The batching rule would be either something naive (like running the backward pass in a for-loop), a call into a pre-written CUDA kernel in the third-party library (sometimes one just has to finagle the inputs), or, in the worst case, if the user really wanted something optimized, a new CUDA kernel.

I just wanted to check: would something like this be helpful? I imagine writing a new CUDA kernel would be a lot of work, but even the naive approach (running the backward pass in a for-loop) could still let you use the vmap API (and get speedups from other parts of the model).
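The naive for-loop batching rule described above can be sketched with the `vmap` staticmethod that autograd.Function accepts in PyTorch 2.x (see the extending.func notes linked at the end of this thread). This is a minimal sketch under stated assumptions: `NumpySin` is a hypothetical stand-in for a Function backed by a third-party kernel (NumPy plays the role of the opaque CUDA kernel), and the rule simply calls the unbatched Function once per example.

```python
import numpy as np
import torch
from torch.func import grad, vmap


class NumpySin(torch.autograd.Function):
    """Hypothetical stand-in for a Function whose forward leaves PyTorch,
    so vmap cannot trace through it and needs a hand-written batching rule."""

    @staticmethod
    def forward(x):
        # "Opaque kernel": compute on a NumPy copy (NumPy cannot see autograd history).
        return torch.as_tensor(np.sin(x.detach().cpu().numpy()), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.cos(x)

    @staticmethod
    def vmap(info, in_dims, x):
        # Naive batching rule: run the unbatched forward once per example.
        x_bdim, = in_dims
        if x_bdim is None:  # input not batched: nothing to loop over
            return NumpySin.apply(x), None
        outs = [NumpySin.apply(xi) for xi in x.unbind(x_bdim)]
        return torch.stack(outs), 0  # the output's batch dimension is 0


x = torch.randn(4, 3)
print(vmap(NumpySin.apply)(x).shape)  # torch.Size([4, 3])
```

The sketch only batches the forward. For the jacobian case discussed in the thread, the backward would need the same treatment if it also called an opaque kernel; one option (an assumption, not something stated in this thread) is to wrap that kernel call in its own Function with a `vmap` staticmethod.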

floatingCatty commented 1 year ago

Hello, I've recently encountered a similar problem: I want to use functorch to compute gradients through a customized torch.autograd.Function. In my case, all operations in this customized Function, including the forward and backward passes, are implemented with PyTorch operations. Is there any way to make functorch work in this situation?
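For a Function like this, whose forward and backward are pure PyTorch, the API that eventually landed (described at the end of this thread) lets torch.func generate the batching rule automatically via `generate_vmap_rule = True`. A minimal sketch, assuming PyTorch 2.x: `RoundSTE` is a hypothetical straight-through rounding Function, used here to compute per-sample gradients with vmap(grad(...)).

```python
import torch
from torch.func import grad, vmap


class RoundSTE(torch.autograd.Function):
    # forward and backward use only PyTorch ops, so torch.func can
    # derive the vmap rule automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):  # no ctx here: torch.func requires the setup_context style
        return torch.round(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # the straight-through gradient needs nothing saved

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat d round(u)/du as 1.
        return grad_output


def loss(w, x):
    return (RoundSTE.apply(w * x) ** 2).sum()


w = torch.tensor([0.5, -1.5, 2.0])
xs = torch.randn(4, 3)
# Per-sample gradients w.r.t. w: grad over w, vmap over the rows of xs.
per_sample = vmap(grad(loss), in_dims=(None, 0))(w, xs)
print(per_sample.shape)  # torch.Size([4, 3])
```

With the straight-through rule, each row should equal 2 * round(w * x) * x, which is what the chain rule gives when d round(u)/du is replaced by 1.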

ryanboustany commented 1 year ago

Hi @zou3519,

Here is one small example, if you are interested:

import torch
import torch.nn as nn
from torch.func import grad

class ReLUAlphaFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, alpha):
        ctx.save_for_backward(input)
        ctx.alpha = alpha
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input[input == 0] = ctx.alpha

        return grad_input, None

class ReLUAlpha(nn.Module):

    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

relu_10 = ReLUAlpha(10)

input = torch.ones(1, 1, 4, 4)
# grad requires a function with a scalar output, so reduce with .sum()
gradient = grad(lambda x: relu_10(x).sum())(input)

RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this

Thanks for your help.

zou3519 commented 1 year ago

@ryanboustany if you're feeling ambitious, we're testing out the feature on the PyTorch master branch and would love early feedback. It's gated behind a feature flag and you'll have to download the latest PyTorch nightly binary to try it out. Here's what your script would look like with it (https://gist.github.com/zou3519/6e67fa31a8ebedfd1f198dc21e5d5993); nightly binary installation instructions over at https://pytorch.org/
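For readers without access to the gist: a minimal sketch of how the ReLUAlpha example above might be refactored for torch.func, assuming PyTorch 2.x (where no feature flag is needed). It follows the documented pattern of a ctx-free `forward` plus `setup_context`, and rewrites the in-place masking in `backward` with `torch.where`; it is not the gist's code.

```python
import torch
from torch.func import grad, jacrev


class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True  # needed for jacrev/vmap; the body is pure PyTorch

    @staticmethod
    def forward(input, alpha):  # torch.func style: no ctx in forward
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.alpha = alpha  # non-tensor values can be stored on ctx directly

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Same rule as the original: grad_output where input > 0, 0 where
        # input < 0, and the constant alpha exactly at 0, but computed
        # out-of-place with torch.where instead of in-place indexing.
        grad_input = torch.where(input == 0, ctx.alpha, grad_output)
        grad_input = torch.where(input < 0, 0.0, grad_input)
        return grad_input, None  # no gradient for the Python float alpha


x = torch.tensor([[0.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
f = lambda t: ReLUAlphaFunction.apply(t, 10.0).sum()
print(grad(f)(x))
print(jacrev(f)(x))
```

For this input both calls should return [[10, 1, 1], [1, 1, 1]]: the entry at the zero input gets alpha, the positive entries get the upstream gradient of 1.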

ryanboustany commented 1 year ago

Thank you for your helpful answer.

1- I have tested the ReLU custom gradient. grad and jacrev (the JAX-like transforms) work very well. However, jacfwd fails with: "NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD."

I wonder what the syntax is for writing the custom gradient for forward-mode AD so that jacfwd works. It would be interesting to compare what happens in forward-mode vs. reverse-mode AD for custom gradients. Note that grad, jacrev, and jacfwd all work well if you use torch.nn.ReLU instead of the custom gradient.

2- I also tested a custom gradient for max pooling.


import torch
import torch.nn as nn
from torch.func import jacrev

class MyMaxPool(torch.autograd.Function):
    generate_vmap_rule = True
    @staticmethod
    def forward(input, kernel_size):
        N, C, H, W = input.size()
        x = input.reshape(N, C, H // kernel_size, kernel_size, W // kernel_size, kernel_size)
        return torch.amax(torch.amax(x, 3), 4)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, kernel_size = inputs
        ctx.save_for_backward(input)
        ctx.kernel_size = kernel_size

    @staticmethod
    def backward(ctx, grad_output):

        input, = ctx.saved_tensors
        N, C, H, W = input.size()

        x_reshaped = input.reshape(N, C, H // ctx.kernel_size, ctx.kernel_size, W // ctx.kernel_size, ctx.kernel_size)
        out = torch.amax(torch.amax(x_reshaped,3),4)

        dx_reshaped = torch.zeros(x_reshaped.size())
        out_newaxis = out[:, :, :, None, :, None]
        mask = (x_reshaped == out_newaxis)
        dout_newaxis = grad_output[:, :, :, None, :, None]
        dout_broadcast = torch.broadcast_to(dout_newaxis, dx_reshaped.size())
        dx_reshaped[mask] = dout_broadcast[mask]
        grad_x = dx_reshaped.reshape(input.size())

        return grad_x, None  # one gradient per forward input: (input, kernel_size)

class MaxPoolings(nn.Module):

    def __init__(self, kernel_size):
        super(MaxPoolings, self).__init__()
        self.kernel_size = kernel_size
    def forward(self, input):
        return MyMaxPool.apply(input, self.kernel_size)

input = torch.randn(1, 1, 4, 4)
maxpool = MaxPoolings(2)
gradient = jacrev(lambda x: maxpool(x).sum())(input)

RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_. If said operator is being called inside the PyTorch framework, please file a bug report instead.

I don't understand why this fails. Note that the grad function works well! I also wonder what I have to do to make jacfwd work...

Thanks again. I hope this can help you.

zou3519 commented 1 year ago

> However, for jacfwd, it fails with: "NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD."

See https://pytorch.org/docs/stable/notes/extending.html#forward-mode-ad for context. The TL;DR is that you'd need to add a jvp staticmethod that defines what the forward-mode AD gradient should be.

That being said, support for jacfwd is actually still a WIP: you'll need to wait until https://github.com/pytorch/pytorch/pull/91211/ is merged (likely later today or tomorrow). There is an example here of how to write a jvp staticmethod.
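As an illustration of the jvp staticmethod, here is a sketch assuming PyTorch 2.x with jacfwd support merged; `MyCube` is a hypothetical example Function, not from the thread. `jvp` receives one tangent per forward input and returns the output tangent; tensors it needs are saved with `ctx.save_for_forward` and read back through `ctx.saved_tensors`.

```python
import torch
from torch.func import jacfwd, jacrev


class MyCube(torch.autograd.Function):
    generate_vmap_rule = True  # jacfwd vmaps over jvp, so a vmap rule is needed

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)  # read by backward
        ctx.save_for_forward(x)   # read by jvp

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2  # vector-Jacobian product

    @staticmethod
    def jvp(ctx, x_t):
        # One tangent per forward input; return J(x) @ x_t.
        # For an elementwise op, J is diagonal, so this is elementwise too.
        x, = ctx.saved_tensors
        return 3 * x ** 2 * x_t


x = torch.tensor([1.0, 2.0, -0.5], dtype=torch.float64)
J_fwd = jacfwd(MyCube.apply)(x)
J_rev = jacrev(MyCube.apply)(x)
print(torch.allclose(J_fwd, J_rev))  # True
```

Both jacobians should be diag(3 * x**2); jacfwd exercises `jvp`, while jacrev exercises `backward`.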

> 2- I tested with a custom gradient of MaxPooling.

Change dx_reshaped = torch.zeros(x_reshaped.size()) to dx_reshaped = grad_output.new_zeros(x_reshaped.size()). A longer explanation is at https://pytorch.org/functorch/stable/ux_limitations.html#mutation-in-place-pytorch-operations, but in short this is a vmap limitation: the in-place dx_reshaped[mask] = ... writes a vmapped tensor into one that torch.zeros created without a batch dimension, so we need to construct the new Tensor from an existing (vmapped) tensor, here grad_output, instead.
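The vmap limitation behind that error can be reproduced without autograd.Function. In this sketch (hypothetical helper names `scatter_bad` and `scatter_good`), a buffer created with torch.zeros has no batch dimension, so writing vmapped values into it in place fails, while a buffer created with `new_zeros` from the vmapped input is itself batched.

```python
import torch
from torch.func import vmap

mask = torch.tensor([True, False, True])  # fixed mask, not vmapped


def scatter_bad(v):
    out = torch.zeros(3)   # fresh tensor: no batch dimension under vmap
    out[mask] = v[mask]    # in-place index_put_ of vmapped values
    return out


def scatter_good(v):
    out = v.new_zeros(3)   # created from v, so vmap batches it too
    out[mask] = v[mask]
    return out


v = torch.randn(5, 3)
try:
    vmap(scatter_bad)(v)
except RuntimeError:
    print("scatter_bad raised RuntimeError")

print(vmap(scatter_good)(v))  # column 1 is zero, other columns copy v
```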

JackShi9 commented 1 year ago

A new potential use case: pytorch/pytorch#91629? @zou3519

ryanboustany commented 1 year ago

@zou3519 I tried adding a jvp method, but it doesn't seem to give the desired result. Do you have an idea why? What is the jvp method supposed to return? Here, jacfwd doesn't give back my output_tangent matrix.

Thank you.

import torch
import torch.nn as nn
from torch.func import grad, jacrev, jacfwd

# Feature flag required by the nightly builds used in this thread
torch._C._set_autograd_function_extension_enabled(True)

class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True
    @staticmethod
    def forward(input, alpha):
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.save_for_forward(input)
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input[input == 0] = ctx.alpha

        return grad_input, None

    @staticmethod
    def jvp(ctx, *tangents):
        input, = ctx.saved_tensors
        output_tangent = input.clone()
        output_tangent[input < 0] = 0
        output_tangent[input > 0] = 1
        output_tangent[input == 0] = ctx.alpha
        return output_tangent

class ReLUAlpha(nn.Module):

    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

input = torch.tensor([[0.0,2.0, 3.0],[1.0,2.0, 3.0]])
relu_10 = ReLUAlpha(10)
print(jacfwd(lambda x: relu_10(x).sum())(input))
print(jacrev(lambda x: relu_10(x).sum())(input))

Output:
tensor([[15., 15., 15.], [15., 15., 15.]])  # jacfwd
tensor([[10., 1., 1.], [ 1., 1., 1.]])  # jacrev

zou3519 commented 1 year ago

@ryanboustany the jvp method should look like this: https://gist.github.com/zou3519/89790dc3e94e3c3fde313886143c06fa (you cloned the wrong tensor).

That being said, the result from jacfwd is different from the result of jacrev. I think this is expected because the autograd.Function isn't computing mathematically correct gradients.

EDIT: I was able to get the jacrev and jacfwd to return the same value. I think this is what you want. (https://gist.github.com/zou3519/726e7929e63b3696c0b97d4e3ae852a1).

For some more context, the jvp method is supposed to compute the jacobian-vector product: the jacobian of the Function evaluated at the input, matrix-multiplied with the tangent (the "vector"). Under the hood, jacfwd(relu_10)(input) essentially vmaps the jvp over tangents that are the standard basis vectors, so the batched outputs of autograd.Function.jvp together form the full jacobian.
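To make the jacfwd-as-vmap-over-jvp description concrete, here is a sketch with an ordinary function (`f` and `jvp_along` are illustrative names, not from the thread): vmapping torch.func.jvp over the standard basis vectors yields the columns of the jacobian, which should match jacfwd. This mirrors the idea only; jacfwd's actual implementation may differ in details.

```python
import torch
from torch.func import jacfwd, jvp, vmap


def f(x):
    return torch.sin(x) * x.sum()  # any R^n -> R^m function


x = torch.randn(4, dtype=torch.float64)
basis = torch.eye(4, dtype=torch.float64)  # standard basis e_0 .. e_3 as rows


def jvp_along(tangent):
    # Jacobian-vector product J(x) @ tangent, computed with forward-mode AD
    _, out_tangent = jvp(f, (x,), (tangent,))
    return out_tangent


# J @ e_i is column i of J; out_dims=1 stacks those results as columns.
J_manual = vmap(jvp_along, out_dims=1)(basis)
J_jacfwd = jacfwd(f)(x)
print(torch.allclose(J_manual, J_jacfwd))  # True
```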

ryanboustany commented 1 year ago

@zou3519 Indeed, your code does what I want to do. For me, jacrev and jacfwd are supposed to compute the same thing (modulo numerical precision, i.e. 32-bit or 64-bit floats). They compute the same thing in theory but do not perform the same numerical operations. Here is an example:

import torch
import torch.nn as nn
from torch.func import grad, jacrev, jacfwd

class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True
    @staticmethod
    def forward(input, alpha):
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.save_for_forward(input)
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input = torch.where((input == 0) * (grad_input != 0), ctx.alpha, grad_input)
        return grad_input, None

    @staticmethod
    def jvp(ctx, *tangents):
        input, = ctx.saved_tensors
        tangent, _ = tangents
        output_tangent = tangent.clone()
        output_tangent[input < 0] = 0
        output_tangent = torch.where((input == 0) * (output_tangent != 0), ctx.alpha, output_tangent)
        return output_tangent

class ReLUAlpha(nn.Module):

    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

relu_10 = ReLUAlpha(10)

model = torch.nn.Sequential(
          nn.Linear(10, 10, bias=False),
          relu_10,
          nn.Linear(10,2, bias=False))

def f(x):
    return model(x).sum()

x = torch.randn(1,1,10,10)
x[0][0][0][0] = 0.0
x.requires_grad_()

loss = f(x)
loss.backward()

jacobian_fwd = jacfwd(f)(x)
jacobian_bwd = jacrev(f)(x)
jacobian_grad = grad(f)(x)

print(torch.linalg.norm(x.grad - jacobian_fwd))
print(torch.linalg.norm(jacobian_bwd  - jacobian_fwd))
print(torch.linalg.norm(jacobian_grad  - jacobian_fwd))

Output:
tensor(1.1031e-07, grad_fn=<...>)
tensor(1.1031e-07, grad_fn=<...>)
tensor(1.1031e-07, grad_fn=<...>)


Anyway, customizing gradient and using jacrev/jacfwd works very well! Thanks a lot!

One last thing: we should also be able to compute jacrev and jacfwd with respect to the model's trainable parameters, not only with respect to the input, i.e. something equivalent to:

grad_neural1 = [param.grad for param in model.parameters()]

I don't have the solution yet.

grad_neural1_fwd = [jacfwd(f)(param) for param in model.parameters()]

This doesn't seem to work: it treats param as the input rather than as a weight.

zou3519 commented 1 year ago

@ryanboustany one needs to write a function that accepts parameters in order to compute jacobians for those parameters. The API to do that in functorch is make_functional: https://pytorch.org/functorch/stable/generated/functorch.make_functional.html?highlight=make_functional#functorch.make_functional
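make_functional is the functorch-era API; in torch.func (PyTorch 2.x) the corresponding tool is torch.func.functional_call, which runs a module with an explicit dict of parameters. A minimal sketch, assuming that API, of taking jacrev and jacfwd with respect to the weights of a model shaped like the one above (nn.ReLU stands in for the custom ReLUAlpha to keep the sketch self-contained):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacfwd, jacrev

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(10, 10, bias=False), nn.ReLU(), nn.Linear(10, 2, bias=False)
)
x = torch.randn(3, 10)

# Parameters become an explicit (dict) argument instead of module state.
params = {name: p.detach() for name, p in model.named_parameters()}


def f(params, x):
    return functional_call(model, params, (x,)).sum()


grads_rev = jacrev(f)(params, x)  # dict with the same keys/shapes as params
grads_fwd = jacfwd(f)(params, x)

# Reference: ordinary autograd populates .grad on the module's parameters.
model(x).sum().backward()
for name, p in model.named_parameters():
    # expected: True for every parameter
    print(name, torch.allclose(grads_rev[name], p.grad, atol=1e-6))
```

Because the output is a scalar, each jacobian has the same shape as its parameter, so the dicts line up one-to-one with the `param.grad` values from the ordinary backward pass.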

Also, if you need more support, please open another issue and let's talk more there -- I'm about to mark this one as complete :)

zou3519 commented 1 year ago

To everyone else in this thread:

autograd.Function now works with functorch! The functionality is now in the PyTorch nightly binaries. The TL;DR is that you'll have to refactor existing autograd.Functions in order to use them.

Please see https://pytorch.org/docs/master/notes/extending.func.html for more details. We'd love to get some more feedback on this, so please open more issues if you run into problems.
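As a quick check of the original issue's grad(grad(foo)) case, here is a sketch of a Function refactored in the documented style (`MyTanh` is a hypothetical example, assuming PyTorch 2.x). Because its backward is written in differentiable PyTorch ops, nesting grad should give the same second derivative as the built-in torch.tanh.

```python
import torch
from torch.func import grad


class MyTanh(torch.autograd.Function):
    @staticmethod
    def forward(x):  # refactored style: forward takes no ctx
        return torch.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)  # d tanh(x)/dx = 1 - tanh(x)^2

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        # Written in differentiable PyTorch ops, so an outer grad
        # can differentiate through this backward as well.
        return grad_output * (1 - y ** 2)


x = torch.tensor(0.3)
d2_custom = grad(grad(MyTanh.apply))(x)
d2_builtin = grad(grad(torch.tanh))(x)
print(torch.allclose(d2_custom, d2_builtin))  # True
```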