As the title says, I tried to combine grad and jvp, and PyTorch raised an error. I compiled functorch and PyTorch from master yesterday, so this should be an unresolved issue. I managed to come up with a minimal script, but it fails.
Thanks for the bug report and the repro script. I was able to reproduce this and extracted a smaller repro; something is wrong on our side, and we'll look into it:
import torch
import torch.nn.functional as F
from functorch import grad, jvp

conv_weight = torch.randn(6, 1, 30, 30)

def model(weights, x):
    x = F.conv2d(x, weights)
    return x.view(x.size(0), -1)

def loss_fun(param, input_tensor):
    out = model(param, input_tensor)
    return F.log_softmax(out, dim=1).sum()

input_tensor = torch.rand(1, 1, 32, 32)
vector = torch.ones_like(input_tensor)

def grad_f(input_tensor):
    # gradient of the loss w.r.t. the weights, as a function of the input
    return grad(loss_fun)(conv_weight, input_tensor)

# forward-over-reverse: jvp of the gradient function in the direction `vector`
print(jvp(grad_f, (input_tensor,), (vector,)))
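For reference, the quantity this repro computes can be cross-checked with plain double backward: because mixed partial derivatives commute, the jvp of grad_f in the input direction `vector` equals the weight-gradient of the inner product between `vector` and the input-gradient of the loss. A minimal sketch (not part of the original thread), reusing the shapes from the repro above:

import torch
import torch.nn.functional as F

conv_weight = torch.randn(6, 1, 30, 30, requires_grad=True)
input_tensor = torch.rand(1, 1, 32, 32, requires_grad=True)
vector = torch.ones_like(input_tensor)

out = F.conv2d(input_tensor, conv_weight)
loss = F.log_softmax(out.view(out.size(0), -1), dim=1).sum()

# First pass: d(loss)/d(input), kept differentiable for a second pass.
(g_x,) = torch.autograd.grad(loss, input_tensor, create_graph=True)

# Second pass: d/d(weights) of <g_x, vector> equals jvp(grad_f)(input)[vector]
# by symmetry of mixed partials.
(mixed,) = torch.autograd.grad((g_x * vector).sum(), conv_weight)
print(mixed.shape)  # torch.Size([6, 1, 30, 30]), same shape as conv_weight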
I have root-caused this to https://github.com/pytorch/pytorch/issues/81111
@cyyever this has been fixed in PyTorch and will be in the next release. If you want to use it earlier, please try a PyTorch nightly (and build functorch from source)
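For readers on newer PyTorch releases: the functorch transforms were later upstreamed into the torch.func namespace, so the repro should run without a separate functorch install. A minimal sketch, assuming torch.func.grad and torch.func.jvp (available since PyTorch 2.0):

import torch
import torch.nn.functional as F
from torch.func import grad, jvp

conv_weight = torch.randn(6, 1, 30, 30)

def loss_fun(weights, x):
    out = F.conv2d(x, weights).view(x.size(0), -1)
    return F.log_softmax(out, dim=1).sum()

def grad_f(x):
    # gradient of the loss w.r.t. the weights, as a function of the input
    return grad(loss_fun)(conv_weight, x)

input_tensor = torch.rand(1, 1, 32, 32)
vector = torch.ones_like(input_tensor)
print(jvp(grad_f, (input_tensor,), (vector,)))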
Closing because this has been resolved