Open sgstepaniants opened 1 year ago
Hi @sgstepaniants!
This should be possible by doing `vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, v)`.
For a little more color, what `in_dims` says here is where the batch dimension of each tensor is: 0 means it is the 0th dimension, and None means there is no batch dimension. So in the innermost vmap, we consider x batched and v not batched; in the outer one, vice versa. If the batch dimensions are not the 0th dimension, you can use another int to specify the correct dimension (for instance, -1).
Let us know if this doesn't work for you or if you have any other questions.
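For concreteness, here is a minimal sketch of the nested-vmap pattern described above; the toy function `f`, the linear map `A`, and the sizes `n`, `m`, `d`, `D` are all made up for illustration:

```python
import torch
from torch.func import vmap

n, m, d, D = 3, 5, 4, 7  # hypothetical batch and feature sizes
A = torch.randn(D, d)

def f(x, v):
    # A toy function of a point x and a tangent vector v, returning shape (D,).
    return (A @ x) * (A @ v)

x = torch.randn(n, d)  # batch of points
v = torch.randn(m, d)  # batch of tangent vectors

# The inner vmap batches over x (dim 0), the outer vmap batches over v (dim 0),
# so the result has shape (m, n, D): one output per (v_j, x_i) pair.
out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, v)
print(out.shape)  # torch.Size([5, 3, 7])
```

Swapping which `in_dims` is None swaps which batch ends up as the leading output dimension.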
Thank you! So I ended up doing something like this to get jvp and vjp batched over inputs and tangent vectors. Does this make sense?
```python
import torch
import torch.nn.functional as F
from torch.func import vmap, jvp, vjp

d = 5
D = 16
weight = torch.randn(D, d)
bias = torch.randn(D)
xs = torch.randn(2, d)    # batch of input points
us = torch.randn(10, d)   # batch of tangent vectors (for jvp)
vs = torch.randn(10, D)   # batch of cotangent vectors (for vjp)

def predict(x):
    return F.linear(x, weight, bias).tanh()

# Reference Jacobian at xs[k], built row by row with autograd.
unit_vectors = torch.eye(D)

def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

k = 0
xp = xs[k].clone().requires_grad_()
jacobian = compute_jac(xp).detach()

def compute_jvp(x, v):
    return jvp(predict, (x,), (v,))[1]

# Inner vmap batches over tangents, outer vmap over inputs: shape (2, 10, D).
ft_jvp = vmap(vmap(compute_jvp, in_dims=(None, 0)), in_dims=(0, None))(xs, us)

def compute_vjp(x):
    return vmap(vjp(predict, x)[1])(vs)

# compute_vjp returns a one-element tuple per input, hence the trailing comma.
ft_vjp, = vmap(compute_vjp)(xs)

print(torch.norm(ft_vjp[0, :, :] - vs @ jacobian))
print(torch.norm(ft_jvp[0, :, :] - us @ jacobian.T))
```
Actually, this is a lot faster:

```python
from torch.func import jacrev

# Compute the full Jacobian at every input point, then contract with the
# batches of tangent/cotangent vectors via einsum.
ft_jacobians = vmap(jacrev(predict))(xs)            # shape (2, D, d)
ft_jvp2 = torch.einsum("ikl, jl -> ijk", ft_jacobians, us)
ft_vjp2 = torch.einsum("ikl, jk -> ijl", ft_jacobians, vs)

print(torch.norm(ft_jvp2[0, :, :] - us @ jacobian.T))
print(torch.norm(ft_vjp2[0, :, :] - vs @ jacobian))
```
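As a sanity check, the two approaches should agree numerically. A self-contained sketch (the sizes and the `predict` network are illustrative, mirroring the snippets above):

```python
import torch
import torch.nn.functional as F
from torch.func import vmap, jvp, jacrev

# Hypothetical small problem mirroring the thread's snippets.
d, D, n, m = 5, 16, 2, 10
weight = torch.randn(D, d)
bias = torch.randn(D)
xs = torch.randn(n, d)
us = torch.randn(m, d)

def predict(x):
    return F.linear(x, weight, bias).tanh()

# Approach 1: nested vmap over a single jvp call.
jvp_a = vmap(vmap(lambda x, v: jvp(predict, (x,), (v,))[1],
                  in_dims=(None, 0)), in_dims=(0, None))(xs, us)

# Approach 2: materialize all Jacobians, then one einsum contraction.
jacs = vmap(jacrev(predict))(xs)                  # (n, D, d)
jvp_b = torch.einsum("ikl, jl -> ijk", jacs, us)  # (n, m, D)

print(torch.allclose(jvp_a, jvp_b, atol=1e-5))
```

Materializing the Jacobians costs extra memory, but it amortizes the autodiff work across all m tangents, which is plausibly why it is faster here when m is large relative to D and d.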
I want to compute the Jacobian-vector product of a function F from R^d to R^D, but I need to do this at a batch of points x_1, ..., x_n in R^d and a batch of tangent vectors v_1, ..., v_m in R^d. Namely, for all i = 1, ..., n and j = 1, ..., m, I need to compute the n·m Jacobian-vector products J_F(x_i) v_j.
Is there a way to do this by using vmap twice to loop over the batches x_i and v_j?