pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

torch.autograd.functional.vjp returns 0 gradients when executed in vmap #922

Closed macio232 closed 2 years ago

macio232 commented 2 years ago

Env: torch==1.11.0+cu113 functorch==0.1.1

Problem: I was trying to extend https://github.com/rtqichen/torchdiffeq/blob/8df757cb12f231a6b4349a96608b7a9d11166988/torchdiffeq/_impl/odeint.py#L130 to support batching, and vmap seemed like a perfect fit. Unfortunately, the backward pass involves a call to torch.autograd.functional.vjp, which apparently is not compatible with the BatchedTensor that vmap introduces.

Code example:

import torch
from functorch import vmap

def exp_reducer(x):
    return x.exp().sum(dim=1)

inputs = torch.rand(10, 4, 4)
v = torch.ones(4)

vmaped_vjp = vmap(
    torch.autograd.functional.vjp,
    in_dims=(None, 0, None),
    out_dims=(0)
)

manual_vjp = []
for batch in inputs:
    manual_vjp.append(torch.autograd.functional.vjp(exp_reducer, batch, v)[1])
assert (
        vmaped_vjp(exp_reducer, inputs, v)[1] == torch.stack(manual_vjp, dim=0)
).all()

The assertion fails because the vmapped call returns zero gradients. The behavior is the same on GPU.

In reference to https://github.com/pytorch/pytorch/issues/42368#issuecomment-1168884014

zou3519 commented 2 years ago

Thanks for the detailed repro, @macio232.

This is expected, and the main branch already raises a detailed error message for it. torch.autograd.functional has a design limitation that prevents us from vmapping over it in many cases (including the one in your script). The workaround is to use functorch.vjp instead of torch.autograd.functional.vjp; the functorch variants of the torch.autograd.functional APIs are arbitrarily composable with each other.

Rewriting your example gives the following, which does not error for me:

import torch
from functorch import vmap, vjp

def exp_reducer(x):
    return x.exp().sum(dim=1)

inputs = torch.rand(10, 4, 4)
v = torch.ones(4)

def compute_vjp(inputs, v):
    _, vjp_fn = vjp(exp_reducer, inputs)
    result, = vjp_fn(v)
    return result

vmaped_vjp = vmap(compute_vjp, in_dims=(0, None))

manual_vjp = []
for batch in inputs:
    manual_vjp.append(torch.autograd.functional.vjp(exp_reducer, batch, v)[1])

result = vmaped_vjp(inputs, v)
expected = torch.stack(manual_vjp, dim=0)
assert torch.allclose(result, expected)
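
For anyone reading this on a newer install: functorch was later folded into PyTorch as torch.func, so the same workaround can be written without the functorch package. The sketch below assumes PyTorch >= 2.0 and reuses the names from the example above. Instead of comparing against a Python loop, it checks the batched result against the closed-form VJP of exp_reducer, which for a cotangent v is exp(x) * v[:, None]. That closed form is derived here for illustration and is not part of the original report.

```python
import torch
from torch.func import vjp, vmap  # assumption: PyTorch >= 2.0, where functorch lives in torch.func

def exp_reducer(x):
    # x has shape (4, 4) per sample; the output has shape (4,)
    return x.exp().sum(dim=1)

def compute_vjp(x, v):
    # vjp returns (output, vjp_fn); vjp_fn maps a cotangent to a tuple of input gradients
    _, vjp_fn = vjp(exp_reducer, x)
    (grad,) = vjp_fn(v)
    return grad

torch.manual_seed(0)
inputs = torch.rand(10, 4, 4)
v = torch.ones(4)

# Map over the leading batch dimension of inputs; v is shared across the batch
batched = vmap(compute_vjp, in_dims=(0, None))(inputs, v)

# Closed form: out[i] = sum_j exp(x[i, j]), so d(v . out)/dx[i, j] = v[i] * exp(x[i, j])
expected = inputs.exp() * v[:, None]
assert torch.allclose(batched, expected)
```

Because every entry of exp(x) is strictly positive, a result of all zeros (the symptom in the original report) cannot pass this check.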

macio232 commented 2 years ago

Thank you for the explanation and solution!