What’s going on in grad(grad(foo))(x):
```
foo(GradTensor(GradTensor(x)))
MyReLU.apply(GradTensor(GradTensor(x)))
  y = GradTensor(GradTensor(x)).clamp(min=0)
    GradTensor(x).clamp(min=0)
      x.clamp(min=0)
y.grad_fn = MyReLU
```
The inner GradTensor doesn’t see the autograd.Function because it is not represented in the dispatcher :(
Here's a (public) writeup of what the problem is: https://docs.google.com/document/d/1sPRJyP_vkZEY3RbNBcy2hLmXxnvZlV7cB6_-uxyVMAE/edit?usp=sharing
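For concreteness, here is a minimal sketch of the kind of custom Function being traced above. The actual MyReLU isn't shown in this issue, so treat this clamp-based definition as an assumption:

```python
import torch

# Hypothetical MyReLU matching the trace above: forward is clamp(min=0),
# backward zeroes the incoming gradient wherever the input was not positive.
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * (x > 0).to(grad_output.dtype)

def foo(x):
    # Reduce to a scalar so grad(foo) and grad(grad(foo)) are well-defined.
    return MyReLU.apply(x).sum()
```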
Hi @zou3519, I was just wondering if there's been any update on adding custom autograd Function support to FuncTorch? Thank you! :)
We're still thinking about how it would look like. It probably won't happen in the next couple of weeks though :( but we'll keep you updated!
Currently functorch gives the following error when it encounters autograd.Function:

```
functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
```
Are there any examples of how one can rewrite a function to not use autograd.Function if we need to define a custom backward/vjp? I believe this error occurs even without any nesting, just from the use of functorch.vjp.
@rejuvyesh the title was a bit misleading; it turns out autograd.Function can have silently incorrect behavior with all of our transforms.
Are there any examples for how one can rewrite a function to not use autograd.Function if we need to define a custom backward/vjp?
Not yet, but this is pretty top of mind for us
Just to confirm, does this mean that we can't use functorch to compute jacobians for anything that relies on a custom backward function? Pretty unfortunate considering the state of https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html and how it points to functorch. Are there any updates on the horizon?
@hturki yes, that's unfortunately correct. This is still top of mind for us, but it is a bit difficult to implement. We've planned to implement it sometime in the next couple of months.
@hturki is there an example of a custom backward function you want to use with functorch?
Thanks for the update. I'd actually like to hook up a third-party library (tiny-cuda-nn: https://github.com/NVlabs/tiny-cuda-nn/blob/master/bindings/torch/tinycudann/modules.py#L41) which has pytorch bindings. It just currently happens to integrate with torch via the current autograd tooling - if there ends up being a workaround that directly integrates with functorch that would work as well
Thanks for the example. There are two types of autograd.Function: ones whose forward and backward are written with regular PyTorch operations, and ones whose forward and backward call into code PyTorch doesn't know about (e.g. custom CUDA kernels).
That example looks like the latter. We don't support it yet, but, in the future, if someone wants to vmap over an autograd.Function whose forward and backward pass are both custom CUDA kernels, then they will need to write a custom batching rule so that vmap knows how to work with it.
We've finally gotten to revisiting this. I'm curious to hear about other use cases for this feature so we know what to prioritize. Calling on some folks (but please feel free to comment if you have thoughts!): @rejuvyesh, could you tell us a bit about your use case please?
For me, integration with vmap is less important. However, functorch's explicit interface is much nicer for integrating with other autodiff systems, like those in, say, Julia. I use it in PyCallChainRules.jl so that users can integrate their differentiable PyTorch functions in Julia. The biggest use case is actually wrapping large CUDA kernels, which usually have a PyTorch interface, and dispatch performance is usually not that big of a concern. However, as @hturki found, it won't work with tiny-cuda-nn, for example.
Thanks for the feedback
@hturki (and @rejuvyesh if you have thoughts) - for hooking up a third-party library like tiny-cuda-nn: https://github.com/NVlabs/tiny-cuda-nn/blob/master/bindings/torch/tinycudann/modules.py#L41, there is a caveat for jacobian computation (and integration with vmap).
My understanding is that third-party libraries may define cuda kernels for the forward pass and backward pass that are stitched together via autograd.Function.
In order to compute its jacobian, we would need to be able to vmap over the backward pass. So functorch would provide a mechanism to add a vmap (batching) rule for the backward pass (like how autograd.Function allows you to define a custom gradient rule) and the user would have to add the batching rule. The batching rule would either be something naive (like running the backward pass in a for-loop), a call into a pre-written cuda kernel in the third-party library (sometimes one just has to finagle the inputs), or, in the worst case if the user really wanted something optimized, a new cuda kernel.
I just wanted to check - would something like this be helpful? I imagine writing a new cuda kernel would be a lot of work, but just the naive approach (running the backward pass in a for-loop) could still let you use the vmap API (and get speedups from other parts of the model)
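To illustrate the "naive approach" above, here is a rough, self-contained sketch. The names naive_batch_rule and custom_backward are made up for illustration; the idea is just to emulate vmap by running the un-batched function once per batch element and stacking the results:

```python
import torch

# Hypothetical helper: a "naive" batching rule that loops over the batch
# dimension and stacks the per-example results.
def naive_batch_rule(fn):
    def batched(x, *args):
        return torch.stack([fn(xi, *args) for xi in x.unbind(0)], dim=0)
    return batched

# Stand-in for a third-party backward kernel that only understands
# un-batched inputs (e.g. a single cotangent at a time).
def custom_backward(grad_output):
    return grad_output.clamp(min=0)

cotangents = torch.randn(8, 4)                          # a batch of 8 cotangents
grads = naive_batch_rule(custom_backward)(cotangents)   # shape (8, 4)
```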
Hello, I've recently encountered a similar problem, where I want to use functorch to compute gradients through a customized torch.autograd.Function. In my case, though, all operations in this customized Function, including the forward and backward pass, are implemented with PyTorch operations. Is there any method that can make functorch work in these circumstances?
Hi @zou3519,
I have one small example if you are interested:
```python
import torch
import torch.nn as nn
from torch.func import grad

class ReLUAlphaFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, alpha):
        ctx.save_for_backward(input)
        ctx.alpha = alpha
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input[input == 0] = ctx.alpha
        return grad_input, None

class ReLUAlpha(nn.Module):
    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

relu_10 = lambda: ReLUAlpha(10)
input = torch.ones(1, 1, 4, 4)
gradient = grad(ReLUAlpha(10))(input)
```

which raises:

```
RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
```
Thanks for your help.
@ryanboustany if you're feeling ambitious, we're testing out the feature on the PyTorch master branch and would love early feedback. It's gated behind a feature flag and you'll have to download the latest PyTorch nightly binary to try it out. Here's what your script would look like with it (https://gist.github.com/zou3519/6e67fa31a8ebedfd1f198dc21e5d5993); nightly binary installation instructions over at https://pytorch.org/
Thank you for your helpful answer.
1- I have tested the ReLU custom gradient. grad and jacrev work very well. However, jacfwd fails with: "NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD."
I wonder what the syntax is to be able to use jacfwd and write the custom gradient for forward-mode AD. It would be interesting to compare what happens in forward AD vs. reverse AD for custom gradients. Note that grad, jacrev, and jacfwd work well if you use torch.nn.ReLU instead of the custom gradient.
2- I tested with a custom gradient of MaxPooling.
```python
class MyMaxPool(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, kernel_size):
        N, C, H, W = input.size()
        return torch.amax(torch.amax(input.reshape(N, C, H // kernel_size, kernel_size, W // kernel_size, kernel_size), 3), 4)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, kernel_size = inputs
        ctx.save_for_backward(input)
        ctx.kernel_size = kernel_size

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        N, C, H, W = input.size()
        x_reshaped = input.reshape(N, C, H // ctx.kernel_size, ctx.kernel_size, W // ctx.kernel_size, ctx.kernel_size)
        out = torch.amax(torch.amax(x_reshaped, 3), 4)
        dx_reshaped = torch.zeros(x_reshaped.size())
        out_newaxis = out[:, :, :, None, :, None]
        mask = (x_reshaped == out_newaxis)
        dout_newaxis = grad_output[:, :, :, None, :, None]
        dout_broadcast = torch.broadcast_to(dout_newaxis, dx_reshaped.size())
        dx_reshaped[mask] = dout_broadcast[mask]
        grad_x = dx_reshaped.reshape(input.size())
        return grad_x, None, None

class MaxPoolings(nn.Module):
    def __init__(self, kernel_size):
        super(MaxPoolings, self).__init__()
        self.kernel_size = kernel_size

    def forward(self, input):
        return MyMaxPool.apply(input, self.kernel_size)

input = torch.randn(1, 1, 4, 4)
maxpool = MaxPoolings(2)
gradient = jacrev(lambda x: maxpool(x).sum())(input)
```

which fails with:

```
RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
I don't understand why this fails. Note that the grad function works well! Moreover, I wonder what I have to do to make jacfwd work too...
Thanks again. I hope this can help you.
1- I have tested the ReLU custom gradient. grad and jacrev work very well. However, jacfwd fails with: "NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD."
See https://pytorch.org/docs/stable/notes/extending.html#forward-mode-ad for context. The TL;DR is that you'd need to add a jvp staticmethod that defines what the forward-mode AD gradient should be.
That being said, support for jacfwd is actually still a WIP (you'll need to wait until https://github.com/pytorch/pytorch/pull/91211/ is merged, which will likely happen later today or tomorrow); there is an example here of how to write a jvp staticmethod.
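For reference, here is a sketch of a jvp staticmethod for the ReLUAlpha-style Function above (my own sketch, not the linked example, and the class name ReLUAlphaFwd is made up). With a recent PyTorch this runs as-is; on the nightly discussed in this thread you would also need the feature flag from the earlier snippet. The key points are that setup_context must call ctx.save_for_forward so the saved input is visible inside jvp, and that the Function needs a vmap rule (here generate_vmap_rule = True) for jacfwd to work:

```python
import torch
from torch.func import jacfwd

class ReLUAlphaFwd(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, alpha):
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_forward(input)  # needed so jvp can see the input
        ctx.alpha = alpha

    @staticmethod
    def jvp(ctx, input_tangent, alpha_tangent):
        # Multiply the incoming tangent by the chosen subgradient:
        # 1 for x > 0, 0 for x < 0, alpha at x == 0.
        input, = ctx.saved_tensors
        slope = torch.where(
            input > 0,
            torch.ones_like(input),
            torch.where(input == 0, torch.full_like(input, ctx.alpha), torch.zeros_like(input)),
        )
        return input_tangent * slope

x = torch.tensor([0.0, -1.0, 2.0])
print(jacfwd(lambda t: ReLUAlphaFwd.apply(t, 10.0).sum())(x))  # tensor([10., 0., 1.])
```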
2- I tested with a custom gradient of MaxPooling.
Change `dx_reshaped = torch.zeros(x_reshaped.size())` to `dx_reshaped = grad_output.new_zeros(x_reshaped.size())`. A longer explanation is over at https://pytorch.org/functorch/stable/ux_limitations.html#mutation-in-place-pytorch-operations, but this is a vmap limitation: we need to construct the new Tensor using the original tensor as a base.
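To make the limitation concrete, here is a small standalone sketch (mine, not from the thread) of the same pattern: the first version writes vmapped data into a tensor vmap isn't tracking, the second derives the buffer from a vmapped tensor as suggested above.

```python
import torch
from torch.func import vmap

def bad(grad_output):
    out = torch.zeros(4)          # created by a factory call: vmap is not tracking it
    out[grad_output < 0] = 1.0    # in-place write driven by a vmapped tensor
    return out

def good(grad_output):
    out = grad_output.new_zeros(4)  # derived from the vmapped tensor instead
    out[grad_output < 0] = 1.0
    return out

g = torch.randn(3, 4)
# vmap(bad)(g)              # raises an error like the index_put_ one above
print(vmap(good)(g).shape)  # torch.Size([3, 4])
```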
A new potential use case: pytorch/pytorch#91629? @zou3519
@zou3519 I tried to add a jvp method, but it doesn't seem to give the desired result. Do you have an idea? What is the jvp method supposed to return? Here it doesn't return my output_tangent matrix.
Thank you.
```python
import torch
import torch.nn as nn
from torch.func import grad, jacrev, jacfwd

torch._C._set_autograd_function_extension_enabled(True)

class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, alpha):
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.save_for_forward(input)
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input[input == 0] = ctx.alpha
        return grad_input, None

    @staticmethod
    def jvp(ctx, *tangents):
        input, = ctx.saved_tensors
        output_tangent = input.clone()
        output_tangent[input < 0] = 0
        output_tangent[input > 0] = 1
        output_tangent[input == 0] = ctx.alpha
        return output_tangent

class ReLUAlpha(nn.Module):
    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

input = torch.tensor([[0.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
relu_10 = ReLUAlpha(10)
print(jacfwd(lambda x: relu_10(x).sum())(input))
print(jacrev(lambda x: relu_10(x).sum())(input))
```

Output:

```
tensor([[15., 15., 15.], [15., 15., 15.]])
tensor([[10., 1., 1.], [ 1., 1., 1.]])
```
@ryanboustany the jvp method should look like this: https://gist.github.com/zou3519/89790dc3e94e3c3fde313886143c06fa (you cloned the wrong tensor).
That being said, the result from jacfwd is different from the result of jacrev. I think this is expected because the autograd.Function isn't computing mathematically correct gradients.
EDIT: I was able to get the jacrev and jacfwd to return the same value. I think this is what you want. (https://gist.github.com/zou3519/726e7929e63b3696c0b97d4e3ae852a1).
For some more context, the jvp method is supposed to compute the Jacobian-vector product: the Jacobian of the Function evaluated at the input, matrix-multiplied with the tangent (the "vector"). Under the hood, jacfwd(relu_10)(input) essentially does a vmap over the jvp in such a way that the stacked outputs of autograd.Function.jvp form the full Jacobian.
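As a small self-contained illustration of that relationship (my example, not code from this thread): jacfwd can be reproduced by vmapping jvp over a batch of basis tangents, and the stacked jvp outputs are the columns of the Jacobian.

```python
import torch
from torch.func import jvp, vmap, jacfwd

A = torch.randn(2, 3)
f = lambda x: A @ x          # simple function, so the Jacobian is just A
x = torch.randn(3)

basis = torch.eye(3)         # one basis tangent per input coordinate
# Each jvp call produces one column of the Jacobian; vmap stacks them.
cols = vmap(lambda t: jvp(f, (x,), (t,))[1])(basis)   # shape (3, 2)
print(torch.allclose(cols.T, jacfwd(f)(x)))           # True: cols.T == A
```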
@zou3519 Indeed, your code does what I want. For me, jacrev and jacfwd are supposed to compute the same thing (modulo numerical precision, 32-bit or 64-bit). In fact, they compute the same thing in theory but do not perform the same numerical operations. Here is an example:
```python
class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(input, alpha):
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.save_for_forward(input)
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input = torch.where((input == 0) * (grad_input != 0), ctx.alpha, grad_input)
        return grad_input, None

    @staticmethod
    def jvp(ctx, *tangents):
        input, = ctx.saved_tensors
        tangent, _ = tangents
        output_tangent = tangent.clone()
        output_tangent[input < 0] = 0
        output_tangent = torch.where((input == 0) * (output_tangent != 0), ctx.alpha, output_tangent)
        return output_tangent

class ReLUAlpha(nn.Module):
    def __init__(self, alpha):
        super(ReLUAlpha, self).__init__()
        self.alpha = alpha

    def forward(self, input):
        return ReLUAlphaFunction.apply(input, self.alpha)

relu_10 = ReLUAlpha(10)
model = torch.nn.Sequential(
    nn.Linear(10, 10, bias=False),
    relu_10,
    nn.Linear(10, 2, bias=False))

def f(x):
    return model(x).sum()

x = torch.randn(1, 1, 10, 10)
x[0][0][0][0] = 0.0
x.requires_grad_()
loss = f(x)
loss.backward()

jacobian_fwd = jacfwd(f)(x)
jacobian_bwd = jacrev(f)(x)
jacobian_grad = grad(f)(x)

print(torch.linalg.norm(x.grad - jacobian_fwd))
print(torch.linalg.norm(jacobian_bwd - jacobian_fwd))
print(torch.linalg.norm(jacobian_grad - jacobian_fwd))
```

Output:

```
tensor(1.1031e-07, grad_fn=
```
Anyway, customizing the gradient and using jacrev/jacfwd works very well! Thanks a lot!
One last thing: we should be able to compute jacrev and jacfwd with respect to the training parameters of the model, not only with respect to the input, i.e. something that reproduces

```python
grad_neural1 = [param.grad for param in model.parameters()]
```

I don't have the solution yet.

```python
grad_neural1_fwd = [jacfwd(f)(param) for param in model.parameters()]
```

doesn't seem to work: it treats param as an input rather than as a weight of the model.
@ryanboustany one needs to write a function that accepts parameters in order to compute jacobians for those parameters. The API to do that in functorch is make_functional: https://pytorch.org/functorch/stable/generated/functorch.make_functional.html?highlight=make_functional#functorch.make_functional
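A rough sketch of that pattern (my own, assuming a model along the lines of the one above; nn.ReLU stands in for the custom activation just to keep the snippet self-contained):

```python
import torch
import torch.nn as nn
from functorch import make_functional
from torch.func import jacrev

# Stand-in model; in the thread's example the middle layer would be ReLUAlpha(10).
model = nn.Sequential(nn.Linear(10, 10, bias=False), nn.ReLU(), nn.Linear(10, 2, bias=False))

# make_functional returns a stateless version of the module whose parameters
# are passed in explicitly, so transforms can differentiate with respect to them.
func_model, params = make_functional(model)

def loss_fn(params, x):
    return func_model(params, x).sum()

x = torch.randn(1, 1, 10, 10)
# Jacobian with respect to the parameters (argnums=0), one entry per parameter tensor.
param_jacobians = jacrev(loss_fn, argnums=0)(params, x)
print([j.shape for j in param_jacobians])
```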
Also, if you need more support, please open another issue and let's talk more there -- I'm about to mark this one as complete :)
To everyone else in this thread:
autograd.Function now works with functorch! The functionality is now in the PyTorch nightly binaries. The TL;DR is that you'll have to refactor existing autograd.Functions in order to use them.
Please see https://pytorch.org/docs/master/notes/extending.func.html for more details. We'd love to get some more feedback on this, so please open more issues if you run into problems.
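For reference, based on the examples earlier in this thread, the refactor roughly amounts to: drop ctx from forward, move state-saving into a setup_context staticmethod, and either set generate_vmap_rule = True or define a vmap staticmethod. A sketch using the ReLUAlpha example from above (see the extending.func doc for the authoritative description):

```python
import torch
from torch.func import grad

class ReLUAlphaFunction(torch.autograd.Function):
    generate_vmap_rule = True  # let PyTorch derive the vmap rule from the ops used

    @staticmethod
    def forward(input, alpha):               # no ctx argument anymore
        return input.clamp(min=0)

    @staticmethod
    def setup_context(ctx, inputs, output):  # saving state moves here
        input, alpha = inputs
        ctx.save_for_backward(input)
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):          # unchanged from the old-style Function
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        grad_input[input == 0] = ctx.alpha
        return grad_input, None

x = torch.ones(3)
print(grad(lambda t: ReLUAlphaFunction.apply(t, 10.0).sum())(x))
```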
We should fix it. Not quite sure how though.
https://gist.github.com/zou3519/73a6189e21561f6ef5b42874e8a4826f