Closed: macio232 closed this issue 2 years ago
Thanks for the detailed repro, @macio232.
This is expected, and on the main branch we already raise a detailed error message for it. torch.autograd.functional has a design limitation that prevents vmap from working over it in many cases (including the one in your script). The workaround is to use functorch.vjp instead of torch.autograd.functional.vjp; the functorch variants of the torch.autograd.functional APIs are arbitrarily composable with each other.
Rewriting your example gives the following, which does not error for me:
import torch
from functorch import vmap, vjp

def exp_reducer(x):
    return x.exp().sum(dim=1)

inputs = torch.rand(10, 4, 4)
v = torch.ones(4)

def compute_vjp(inputs, v):
    # functorch.vjp returns (output, vjp_fn); vjp_fn maps a cotangent to input grads
    _, vjp_fn = vjp(exp_reducer, inputs)
    result, = vjp_fn(v)
    return result

# Batch compute_vjp over the leading dim of inputs; v is shared across the batch
vmaped_vjp = vmap(compute_vjp, in_dims=(0, None))

# Reference result: loop over the batch with torch.autograd.functional.vjp
manual_vjp = []
for batch in inputs:
    manual_vjp.append(torch.autograd.functional.vjp(exp_reducer, batch, v)[1])

result = vmaped_vjp(inputs, v)
expected = torch.stack(manual_vjp, dim=0)
assert torch.allclose(result, expected)
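As a side note on the composability claim above, the functorch transforms can be stacked further. A small sketch (not part of the original reply) that computes per-sample Jacobians of the same exp_reducer by composing vmap with jacrev:

from functorch import jacrev

# Jacobian of exp_reducer for each sample: shape (10, 4, 4, 4),
# i.e. output (4,) differentiated w.r.t. input (4, 4), batched over 10 samples
per_sample_jacobians = vmap(jacrev(exp_reducer))(inputs)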
Thank you for the explanation and solution!
Env: torch==1.11.0+cu113, functorch==0.1.1
Problem: I was trying to extend https://github.com/rtqichen/torchdiffeq/blob/8df757cb12f231a6b4349a96608b7a9d11166988/torchdiffeq/_impl/odeint.py#L130 for batching, and vmap seems to be a perfect fit here. Unfortunately, backward involves a torch.autograd.functional.vjp call, which apparently is not compatible with the BatchedTensor that vmap introduces.

Code example:
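The original snippet was not preserved in this thread. A minimal sketch of the kind of code that triggers the error, assuming the same exp_reducer setup used in the reply above:

import torch
from functorch import vmap

def exp_reducer(x):
    return x.exp().sum(dim=1)

inputs = torch.rand(10, 4, 4)
v = torch.ones(4)

def compute_vjp(batch, v):
    # Under vmap, batch is a BatchedTensor, which
    # torch.autograd.functional.vjp cannot handle
    return torch.autograd.functional.vjp(exp_reducer, batch, v)[1]

vmap(compute_vjp, in_dims=(0, None))(inputs, v)  # raises an error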
The behavior is the same on GPU.
In reference to https://github.com/pytorch/pytorch/issues/42368#issuecomment-1168884014