kfxw opened this issue 3 years ago
@kfxw what kind of model are you using? Depending on the model we might point you to different solutions.
To compute per-sample gradients using vmap, we cannot call torch.autograd.grad inside of vmap, which I think is what you're trying to do. This is a limitation of how autograd works. Instead, we need to reframe the problem using functorch's grad transform.
There's an example on how to compute per-sample-gradients of a very simple model (that is just one nn.Linear) here: https://github.com/facebookresearch/functorch#working-with-nn-modules-make_functional-and-friends. I copied the example below:
```python
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
Note that we use functorch's grad operator to compute gradients and then use vmap to batch over them to produce per-sample gradients.
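In PyTorch 2.x, functorch's transforms live in torch.func, and torch.func.functional_call replaces make_functional. The sketch below redoes the same per-sample-gradient computation with that API. It assumes a PyTorch 2.x install, and the model and data are the toy ones from the example above.

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

# Detach the parameters into a plain dict so they can be passed in explicitly.
params = {name: p.detach() for name, p in model.named_parameters()}

def compute_loss(params, sample, target):
    # vmap strips the batch dim, so add it back for the module and drop it afterwards.
    # functional_call runs `model` with `params` instead of its own parameters.
    pred = functional_call(model, params, (sample.unsqueeze(0),)).squeeze(0)
    return torch.mean((pred - target) ** 2)

# grad differentiates w.r.t. the first argument (params); vmap maps over dim 0
# of data and targets while sharing params across all samples.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))  # prints "weight (64, 3, 3)", then "bias (64, 3)"
```

Each entry of per_sample_grads has a leading dimension of 64, one gradient per sample, matching what a Python loop over samples with torch.autograd.grad would produce.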
Thanks for your reply!
I've taken a look at the grad example. As I understand it, combining vmap and grad assumes the same data division (parallelization) for both the forward and the backward pass.
In my case, they differ. Briefly, the forward pass is conventional mini-batching, parallelized over the batch dim. The backward pass computes the diagonal of the Jacobian (and Hessian) matrix, parallelized over more dims. grad cannot handle these different parallelizations. This kind of computation is common in physics-related models. Here is an unaccelerated example:
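```python
import torch

# Input data: the 1D coordinates X with shape [batch_size, num_points], e.g. [3, 500].
X = torch.randn(3, 500, requires_grad=True)
model = torch.nn.Linear(500, 500)  # stand-in; the real model is a Conv1d or graph conv
pred = model(X)  # parallelized on the 'batch_size' dim

sum_1st_grad = torch.zeros(3)
# How can this loop be sped up?
for b in range(3):
    for i in range(500):
        # d pred_i / d x_i: differentiate w.r.t. the whole X, then index, because
        # X[b, i] creates a new tensor that is not part of pred's graph.
        g = torch.autograd.grad(pred[b, i], X, retain_graph=True)[0]
        sum_1st_grad[b] += g[b, i]
```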
As for the model itself, it can be a Conv1d or a graph conv imported from pytorch_geometric, so the off-diagonal entries of the Jacobian, d pred_i / d x_j for i ≠ j, may be nonzero.
Thanks for your time!
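The nonzero off-diagonal entries are why a single backward pass cannot replace the loop. Backpropagating pred.sum() gives, for every j, the sum over i of d pred[b, i] / d X[b, j]. Those are the column sums of each sample's Jacobian, not its diagonal. The sketch below checks this with nn.Linear as a stand-in, chosen because its per-sample Jacobian is exactly model.weight.

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 500, requires_grad=True)
model = torch.nn.Linear(500, 500)  # stand-in: per-sample Jacobian == model.weight
pred = model(X)

# One backward pass: one_pass[b, j] = sum_i d pred[b, i] / d X[b, j].
(one_pass,) = torch.autograd.grad(pred.sum(), X)

W = model.weight.detach()
print(torch.allclose(one_pass, W.sum(0).expand(3, -1), atol=1e-5))  # True: column sums
print(torch.allclose(one_pass[0], W.diagonal(), atol=1e-5))          # False: not the diagonal
```

Only when the model acts elementwise (a diagonal Jacobian) do the two coincide.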
Hi @kfxw, thank you for the clarification! You're right -- the example I provided above is different from what you're looking for.
Does something like the following reflect the semantics you want? (It uses nn.Linear, but you can plug in Conv1d or something else.) It sounds like, for each of the 3 elements in the batch (each of size 500), you want the summed diagonal of the Jacobian (or Hessian).
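```python
import torch
from functorch import vmap, jacrev

X = torch.randn(3, 500)
model = torch.nn.Linear(500, 500)

# Size [3, 500, 500]: one 500x500 Jacobian per batch element.
jacs = vmap(jacrev(model))(X)
# Get each diagonal and sum it. The dims are passed by keyword because the
# first positional argument of diagonal() is the offset.
sum_1st_grad = jacs.diagonal(dim1=1, dim2=2).sum(-1)
```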
NB: we might be missing performant support for conv1d, but we're working on it and can prioritize it.
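One way to confirm that the vectorized version computes the same quantity as the explicit double loop is to compare them on a small problem. The sketch below assumes PyTorch 2.x, where jacrev and vmap are available as torch.func.jacrev and torch.func.vmap. It uses nn.Linear as a stand-in model and shrinks the sizes (batch_size=2, num_points=8, chosen arbitrarily) so the reference loop runs quickly.

```python
import torch
from torch.func import jacrev, vmap

torch.manual_seed(0)
batch_size, num_points = 2, 8  # small sizes so the reference loop is fast
model = torch.nn.Linear(num_points, num_points)  # stand-in for the real model
X = torch.randn(batch_size, num_points)

# Vectorized: per-sample Jacobians of shape [batch_size, num_points, num_points].
jacs = vmap(jacrev(model))(X)
# Tensor.diagonal's first positional argument is `offset`, so pass the dims by keyword.
vectorized = jacs.diagonal(dim1=1, dim2=2).sum(-1)

# Reference: the explicit double loop over torch.autograd.grad.
Xg = X.clone().requires_grad_(True)
pred = model(Xg)
reference = torch.zeros(batch_size)
for b in range(batch_size):
    for i in range(num_points):
        (g,) = torch.autograd.grad(pred[b, i], Xg, retain_graph=True)
        reference[b] += g[b, i]

print(torch.allclose(vectorized, reference, atol=1e-5))
```

For nn.Linear both results equal the trace of model.weight for every sample, which makes this a convenient sanity check before swapping in the real model.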
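Actually, we might need one additional step. The function passed to vmap(jacrev(...)) is called on one sample at a time (a tensor of shape [500] with no batch_size dim), so it assumes the model isn't parallelized on the batch_size dim. If the model expects a batch dim, we can resolve this by unsqueezing and squeezing a dimension: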
```python
import torch
from functorch import vmap, jacrev

X = torch.randn(3, 500)
model = torch.nn.Linear(500, 500)

def unparallelized_model(x):
    # x is a single sample of shape [500]; add a batch dim for the model, then drop it.
    return model(x.unsqueeze(0)).squeeze(0)

# Size [3, 500, 500]
jacs = vmap(jacrev(unparallelized_model))(X)
# Get each diagonal and sum it (diagonal()'s first positional argument is the offset).
sum_1st_grad = jacs.diagonal(dim1=1, dim2=2).sum(-1)
```
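For a Conv1d, the wrapper also needs a channel dim, since Conv1d expects input of shape [N, C, L]. The sketch below assumes PyTorch 2.x (torch.func). The single-channel, length-preserving Conv1d (kernel_size=3, padding=1) and the name unparallelized_conv are hypothetical stand-ins chosen for illustration. With this layer, every diagonal entry equals the centre kernel weight, while the neighbouring off-diagonal entries are nonzero.

```python
import torch
from torch.func import jacrev, vmap

torch.manual_seed(0)
X = torch.randn(3, 500)
# Hypothetical stand-in: one input/output channel, output length stays 500.
conv = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

def unparallelized_conv(x):
    # vmap passes one sample of shape [500]; Conv1d wants [N, C, L] = [1, 1, 500].
    return conv(x[None, None]).squeeze(0).squeeze(0)

jacs = vmap(jacrev(unparallelized_conv))(X)           # [3, 500, 500], banded
sum_1st_grad = jacs.diagonal(dim1=1, dim2=2).sum(-1)  # [3]
```

Here d pred_i / d x_i is conv.weight[0, 0, 1] for every i, so each entry of sum_1st_grad should be 500 times that weight. Note that materializing the full [3, 500, 500] Jacobian only to keep its diagonal costs memory that grows quadratically with num_points.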
Thanks for your suggestions.
Yes, this is what I want. In fact, I followed the example in the PyTorch docs to compute the vjp, and it worked well on simple computational graphs. However, when I use torch_geometric.nn.SplineConv in the model, I get a RuntimeError: Cannot access data pointer of Tensor that doesn't have storage. Is that because only a limited number of operators are currently supported? And are there any differences between your example and the one in the PyTorch docs?
> Is it because now you are supporting a limited number of operators that, when I use the torch_geometric.nn.SplineConv in the model, I got a RuntimeError: Cannot access data pointer of Tensor that doesn't have storage error?
That's a bug on our side that can be fixed.
> And is there any differences between using your example and using the one on the pytorch doc?
If the one in the doc works for you, then go for it! I would recommend replacing torch.vmap with functorch.vmap -- functorch.vmap is an updated version that supports more operators.
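```python
>>> N = 5  # assumed setup, not part of the quoted snippet
>>> x = torch.randn(N, requires_grad=True)
>>> y = x ** 2
>>> I_N = torch.eye(N)
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
```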
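Recent PyTorch versions expose the same batched-vjp idea directly through the is_grads_batched flag of torch.autograd.grad, which keeps the ordinary batched forward pass. The sketch below is one way to apply it to the diagonal-sum problem from this thread. Its key assumption is that the samples in the batch do not interact in the forward pass (true for nn.Linear and Conv1d), so a one-hot cotangent at point i for every sample yields row i of each sample's Jacobian. nn.Linear is again a stand-in.

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 500, requires_grad=True)
model = torch.nn.Linear(500, 500)  # stand-in for the real model
pred = model(X)  # ordinary batched forward pass, shape [3, 500]

batch_size, num_points = X.shape
# v[i, b, :] is the one-hot vector e_i for every sample b: shape [500, 3, 500].
v = torch.eye(num_points).unsqueeze(1).repeat(1, batch_size, 1)

# The first dim of grad_outputs is treated as a batch of cotangents, so
# grads[i, b, j] = d pred[b, i] / d X[b, j] (assuming independent samples).
(grads,) = torch.autograd.grad(pred, X, grad_outputs=v, is_grads_batched=True)

# Keep grads[i, b, i] for each b and i, then sum over i.
sum_1st_grad = grads.diagonal(dim1=0, dim2=2).sum(-1)  # shape [3]
```

As with the vmap(jacrev) approach, this still computes every row of each Jacobian and keeps only the diagonal.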
Sounds great! It would be perfect if you could fix this bug; I'm looking forward to it. I would be grateful if you could reference this issue or notify me when the bug is fixed. Thanks!
Let's keep the issue open to track this.
Hi,
I'm having some trouble with vmap in my code. I wanted to use vmap to accelerate an element-wise gradient computation, but found that vmap destroyed the computation graph inside it. Below is the minimal reproduction code. Is this expected? Or is there an alternative way to compute element-wise gradients efficiently? (A more detailed post describing what I need can be viewed here (PyTorch forum).)
To reproduce:
Thanks!