pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Question on how to batch over both inputs and tangent vectors #940

Open · sgstepaniants opened 1 year ago

sgstepaniants commented 1 year ago

I want to compute the Jacobian-vector product (JVP) of a function F: R^d -> R^D, but I need to do this at a batch of points x_1, ..., x_n in R^d and for a batch of tangent vectors v_1, ..., v_m in R^d. That is, for all i = 1, ..., n and j = 1, ..., m, I need the n x m Jacobian-vector products J_F(x_i) v_j, each a vector in R^D.
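For reference, the requested quantity can be written as an explicit double loop. This is a minimal sketch assuming PyTorch 2.0 or later, where functorch's transforms are available as `torch.func`; the map F(x) = tanh(W x + b) and the names `W`, `b`, `n`, `m` are hypothetical stand-ins, not taken from the thread.

```python
import torch
from torch.func import jvp  # functorch.jvp in older releases

torch.manual_seed(0)
n, m, d, D = 3, 4, 5, 16
W = torch.randn(D, d)  # hypothetical example map F(x) = tanh(W x + b)
b = torch.randn(D)

def F(x):
    # F: R^d -> R^D for a single point x of shape (d,)
    return torch.tanh(W @ x + b)

xs = torch.randn(n, d)  # points x_1, ..., x_n
vs = torch.randn(m, d)  # tangent vectors v_1, ..., v_m

# Baseline: one jvp call per (i, j) pair, out[i, j] = J_F(x_i) @ v_j
out = torch.empty(n, m, D)
for i in range(n):
    for j in range(m):
        out[i, j] = jvp(F, (xs[i],), (vs[j],))[1]
```

This makes n * m separate `jvp` calls from Python, which is the loop that nested `vmap` is meant to replace.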

Is there a way to do this by using vmap twice to loop over the batches x_i and v_j?

samdow commented 1 year ago

Hi @sgstepaniants!

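This should be possible with vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, v), where f(x, v) returns the JVP for a single point x and a single tangent vector v. Because the outer vmap batches over v, the result has shape (m, n, D).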
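A minimal sketch of this suggestion, assuming PyTorch 2.0+ (`torch.func.jvp` and `torch.func.vmap`; `functorch.jvp` and `functorch.vmap` in older releases). Here `f` wraps `jvp` for one point and one tangent; `F`, `W` and `b` are a hypothetical example map.

```python
import torch
from torch.func import jvp, vmap

torch.manual_seed(0)
n, m, d, D = 3, 4, 5, 16
W = torch.randn(D, d)  # hypothetical example map F(x) = tanh(W x + b)
b = torch.randn(D)

def F(x):
    return torch.tanh(W @ x + b)

def f(x, v):
    # JVP J_F(x) @ v for ONE point x of shape (d,) and ONE tangent v of shape (d,)
    return jvp(F, (x,), (v,))[1]

xs = torch.randn(n, d)
vs = torch.randn(m, d)

# Inner vmap: x batched along dim 0, v shared. Outer vmap: v batched, x shared.
out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(xs, vs)
print(out.shape)  # torch.Size([4, 3, 16]), i.e. (m, n, D)
```

Swapping the nesting (inner vmap over v, outer vmap over x) gives shape (n, m, D) instead.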

For a little more color: in_dims says where the batch dimension of each input tensor is. 0 means the 0th dimension; None means the tensor has no batch dimension and is shared across the whole batch. So in the inner vmap, x is batched and v is not; in the outer one, it is the other way around. If a batch dimension is not the 0th dimension, you can pass another int to point at it (for instance, -1 for the last dimension).
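To illustrate a non-zero batch dimension, here is a sketch (same hypothetical `F`, same `torch.func` assumption) in which the points and tangents are stored with the batch dimension last, as (d, n) and (d, m). Passing -1 in `in_dims` selects that dimension and reproduces the batch-first result.

```python
import torch
from torch.func import jvp, vmap

torch.manual_seed(0)
n, m, d, D = 3, 4, 5, 16
W = torch.randn(D, d)  # hypothetical example map F(x) = tanh(W x + b)
b = torch.randn(D)

def F(x):
    return torch.tanh(W @ x + b)

def f(x, v):
    return jvp(F, (x,), (v,))[1]

xs = torch.randn(n, d)
vs = torch.randn(m, d)

# Same data stored batch-last: shapes (d, n) and (d, m)
xs_t = xs.T
vs_t = vs.T

out_first = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(xs, vs)
out_last = vmap(vmap(f, in_dims=(-1, None)), in_dims=(None, -1))(xs_t, vs_t)
# Both have shape (m, n, D); in_dims only changes where the batch is read from
```

The output batch dimensions still come first by default; `vmap` also takes an `out_dims` argument if they should be placed elsewhere.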

Let us know if this doesn't work for you or if you have any other questions.

sgstepaniants commented 1 year ago

Thank you! I ended up doing something like this to get jvp and vjp batched over both the inputs and the tangent (or, for vjp, cotangent) vectors. Does this make sense?

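import torch
import torch.nn.functional as F
from functorch import vmap, jvp, vjp, jacrev  # torch.func in PyTorch 2.0+

d = 5
D = 16
weight = torch.randn(D, d)
bias = torch.randn(D)
xs = torch.randn(2, d)   # n = 2 input points
us = torch.randn(10, d)  # m = 10 tangent vectors (for jvp)
vs = torch.randn(10, D)  # m = 10 cotangent vectors (for vjp)

def predict(x):
    return F.linear(x, weight, bias).tanh()

# Reference Jacobian at a single point, built row by row with autograd
unit_vectors = torch.eye(D)
def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)  # (D, d)
k = 0
xp = xs[k].clone().requires_grad_()
jacobian = compute_jac(xp).detach()

def compute_jvp(x, v):
    return jvp(predict, (x,), (v,))[1]
# Inner vmap batches over us, outer over xs -> shape (n, m, D)
ft_jvp = vmap(vmap(compute_jvp, in_dims=(None, 0)), in_dims=(0, None))(xs, us)

def compute_vjp(x):
    # One forward pass per x; the returned vjp_fn is vmapped over all cotangents in vs
    return vmap(vjp(predict, x)[1])(vs)
ft_vjp, = vmap(compute_vjp)(xs)  # vjp_fn returns a 1-tuple; unpack -> (n, m, d)

# Compare against the reference Jacobian at x_k (both norms should be near 0)
print(torch.norm(ft_vjp[k] - vs @ jacobian))
print(torch.norm(ft_jvp[k] - us @ jacobian.T))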
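The printed norms above only check a single point. Since predict(x) = tanh(W x + b), its Jacobian has the closed form diag(1 - predict(x)^2) W, which allows every batch point to be checked at once with `torch.allclose`. A sketch, assuming PyTorch 2.0+ (`torch.func`); the tolerance of 1e-5 is an arbitrary choice for float32.

```python
import torch
import torch.nn.functional as F
from torch.func import jvp, vjp, vmap

torch.manual_seed(0)
d, D = 5, 16
weight = torch.randn(D, d)
bias = torch.randn(D)
xs = torch.randn(2, d)
us = torch.randn(10, d)
vs = torch.randn(10, D)

def predict(x):
    return F.linear(x, weight, bias).tanh()

def compute_jvp(x, u):
    return jvp(predict, (x,), (u,))[1]

def compute_vjp(x):
    _, vjp_fn = vjp(predict, x)
    return vmap(vjp_fn)(vs)[0]  # vjp_fn returns a 1-tuple; keep the tensor

ft_jvp = vmap(vmap(compute_jvp, in_dims=(None, 0)), in_dims=(0, None))(xs, us)  # (n, m, D)
ft_vjp = vmap(compute_vjp)(xs)  # (n, m, d)

# Closed-form Jacobians for all points: J(x) = diag(1 - tanh(W x + b)^2) W
y = predict(xs)  # (n, D); predict also accepts a batch directly
jacobians = (1 - y ** 2)[:, :, None] * weight  # (n, D, d)

# us @ J^T and vs @ J for every point at once, via broadcasting matmul
jvp_ok = torch.allclose(ft_jvp, us @ jacobians.transpose(1, 2), atol=1e-5)
vjp_ok = torch.allclose(ft_vjp, vs @ jacobians, atol=1e-5)
print(jvp_ok, vjp_ok)
```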

sgstepaniants commented 1 year ago

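Actually, this is a lot faster: compute the full Jacobian at each x_i once with vmap(jacrev(predict)), then contract it with all the tangent and cotangent vectors using einsum.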

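# Per-point Jacobians, shape (n, D, d); reuses predict, xs, us, vs and jacobian from above
ft_jacobians = vmap(jacrev(predict))(xs)
ft_jvp2 = torch.einsum("ikl, jl -> ijk", ft_jacobians, us)  # (n, m, D): J(x_i) @ u_j
ft_vjp2 = torch.einsum("ikl, jk -> ijl", ft_jacobians, vs)  # (n, m, d): v_j @ J(x_i)

print(torch.norm(ft_jvp2[0, :, :] - us @ jacobian.T))
print(torch.norm(ft_vjp2[0, :, :] - vs @ jacobian))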
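When the full Jacobian is materialized, it may be worth choosing the mode that builds it more cheaply: `jacrev` assembles it from one vector-Jacobian product per output (D of them), while `jacfwd` uses one Jacobian-vector product per input (d of them). With d = 5 < D = 16 here, `jacfwd` is the natural candidate, though actual speed depends on shapes and hardware, so timing both is advisable. A sketch, assuming PyTorch 2.0+ (`torch.func.jacfwd` / `torch.func.jacrev`). Note that this approach stores an n x D x d tensor, so it pays off mainly when d and D are small.

```python
import torch
import torch.nn.functional as F
from torch.func import jacfwd, jacrev, vmap

torch.manual_seed(0)
d, D = 5, 16
weight = torch.randn(D, d)
bias = torch.randn(D)
xs = torch.randn(2, d)
us = torch.randn(10, d)
vs = torch.randn(10, D)

def predict(x):
    return F.linear(x, weight, bias).tanh()

# Both give per-point Jacobians of shape (n, D, d)
jac_rev = vmap(jacrev(predict))(xs)  # reverse mode: built from D vector-Jacobian products
jac_fwd = vmap(jacfwd(predict))(xs)  # forward mode: built from d Jacobian-vector products

# Contract each Jacobian once with all tangents / cotangents
jvps = torch.einsum("ikl,jl->ijk", jac_fwd, us)  # (n, m, D): J(x_i) @ u_j
vjps = torch.einsum("ikl,jk->ijl", jac_fwd, vs)  # (n, m, d): v_j @ J(x_i)
```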