pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

The computational graph will be destroyed after feeding tensor/list of tensors into vmap #175

Open kfxw opened 3 years ago

kfxw commented 3 years ago

Hi,

I've run into some trouble with vmap in my code. I wanted to use vmap to accelerate element-wise gradient computation, but found that vmap destroys the computation graph inside it. Below is a minimal reproduction. I would like to know whether this is expected behavior, or whether there is an alternative way to compute element-wise gradients efficiently. (A more detailed post about my use case can be viewed here (pytorch forum).)

to reproduce:

import torch
from functorch import vmap

# vmap on tensor
a = torch.rand(5).cuda()
a.requires_grad = True
b = a * 100

print(a.requires_grad, b.requires_grad)

def test1(a, b):
    # Check whether the graph survives inside vmap; the intended computation
    # (commented out) is the element-wise gradient, note the call below passes (b, a).
    return [torch.tensor(a.requires_grad), torch.tensor(b.requires_grad)]
    # return torch.autograd.grad(a, b, retain_graph=True, create_graph=True)[0]

print(vmap(test1)(b, a))

# vmap on list of tensor
a1 = list(a.chunk(5))
b1 = list(b.chunk(5))
for i in range(5):
    print(a1[i].requires_grad, b1[i].requires_grad)

def test2(a, b):
    return [torch.tensor(a[0].requires_grad), torch.tensor(b[0].requires_grad)]
    # return torch.autograd.grad(a[0], b[0], retain_graph=True, create_graph=True)[0]

print(vmap(test2)(b1, a1))

Thanks!

zou3519 commented 3 years ago

@kfxw what kind of model are you using? Depending on the model we might point you to different solutions.

To compute per-sample-gradients using vmap, we cannot call torch.autograd.grad inside of vmap, which I think is what you're trying to do; this is a limitation of how autograd works. Instead, we need to reframe the problem using functorch's grad transform.

There's an example of how to compute per-sample-gradients of a very simple model (just one nn.Linear) here: https://github.com/facebookresearch/functorch#working-with-nn-modules-make_functional-and-friends. I've copied the example below:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

# in_dims=(None, 0, 0): don't map over params; map over dim 0 of data and targets
per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

Note that we use functorch's grad operator to compute gradients and then use vmap to batch over them to produce per-sample-gradients.
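
If it's useful, you can sanity-check the vmapped result against a plain per-sample loop (just a quick sketch on top of the example above, not part of the README):

# Reference: call grad() once per sample and compare with the vmapped result.
expected = [grad(compute_loss)(params, data[i], targets[i]) for i in range(data.shape[0])]

for p, per_sample in enumerate(per_sample_grads):   # one entry per parameter
    for i in range(data.shape[0]):                  # one slice per sample
        assert torch.allclose(per_sample[i], expected[i][p], atol=1e-5)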

kfxw commented 3 years ago

Thanks for your reply!

I've taken a look at the grad example. I think using the combination of vmap and grad assumes the same data division / parallelization for both the forward and backward pass.

However, in my case they are different. Briefly, the forward pass is conventional mini-batching, parallelized over the batch dim. The backward pass computes the diagonal of the Jacobian (and Hessian) matrix, parallelized over more dims. grad is not capable of handling these different parallelizations. This kind of computation is common in physics-related models. I show an unaccelerated example below.

# input data: the 1D coordinates X with shape [batch_size, num_points]
#   e.g. X.shape --> [3, 500]
X.requires_grad_(True)
pred = model(X)     # forward pass parallelized on the 'batch_size' dim
sum_1st_grad = torch.zeros(3)
# how to speed up this loop?
for b in range(3):
    for i in range(500):
        # d pred_i / d x_i: one diagonal entry of the Jacobian
        sum_1st_grad[b] += torch.autograd.grad(pred[b, i], X, retain_graph=True)[0][b, i]

As for the model, it can be a Conv1d or a graph conv imported from pytorch_geometric, so the off-diagonal values of the Jacobian matrix d pred_i / d x_j may not be zero.
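
For instance, here is a toy sketch (not my actual model) showing the off-diagonal coupling for a small Conv1d:

import torch

# A Conv1d mixes neighbouring inputs, so d pred_i / d x_j is nonzero off the diagonal.
conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)

def f(x):
    return conv(x.view(1, 1, -1)).view(-1)   # Conv1d expects [batch, channels, length]

x = torch.randn(8)
J = torch.autograd.functional.jacobian(f, x)          # shape [8, 8]
print((J - torch.diag(J.diagonal())).abs().max())     # > 0: off-diagonal entries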

Thanks for your time!

zou3519 commented 3 years ago

Hi @kfxw, thank you for the clarification! You're right -- the example I provided above is different from what you're looking for.

Does something like the following reflect the semantics you want? (It uses nn.Linear, but you can plug in Conv1d or something else.) It sounds like, for each element in the batch (each of size 500), you want the summed diagonal of the Jacobian (or Hessian):

import torch
from functorch import vmap, jacrev

X = torch.randn(3, 500)
model = torch.nn.Linear(500, 500)

# Size [3, 500, 500]
jacs = vmap(jacrev(model))(X)

# Get each diagonal and sum them.
sum_1st_grad = jacs.diagonal(dim1=1, dim2=2).sum(-1)

NB: we might be missing performant support for conv1d, but we're working on it and can prioritize it.

zou3519 commented 3 years ago

Actually, we might need one additional step. The function passed to vmap(jacrev(...)) is assumed not to be parallelized over the batch_size dim (vmap handles that dim itself). We can resolve this by unsqueezing and squeezing a dimension:

import torch
from functorch import vmap, jacrev

X = torch.randn(3, 500)
model = torch.nn.Linear(500, 500)

def unparallelized_model(x):
    return model(x.unsqueeze(0)).squeeze(0)

# Size [3, 500, 500]
jacs = vmap(jacrev(unparallelized_model))(X)

# Get each diagonal and sum them.
sum_1st_grad = jacs.diagonal(dim1=1, dim2=2).sum(-1)
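
If you want to double-check it, here's a (slow) sanity check of sum_1st_grad against the naive double loop from your example, on this same toy setup -- just a sketch:

# Naive reference: per-element autograd.grad, as in the original loop.
X_check = X.clone().requires_grad_(True)
pred = model(X_check)
expected = torch.zeros(3)
for b in range(3):
    for i in range(500):
        expected[b] += torch.autograd.grad(pred[b, i], X_check, retain_graph=True)[0][b, i]

print(torch.allclose(sum_1st_grad, expected, atol=1e-5))   # should print True
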
kfxw commented 2 years ago

Thanks for your suggestions.

Yes, that is what I want. In fact, I followed the example in the PyTorch docs to get the vjp, and it worked well on simple computational graphs. Is it because you currently support only a limited number of operators that, when I use torch_geometric.nn.SplineConv in the model, I get a RuntimeError: Cannot access data pointer of Tensor that doesn't have storage? And are there any differences between using your example and using the one from the PyTorch docs?

zou3519 commented 2 years ago

Is it because you currently support only a limited number of operators that, when I use torch_geometric.nn.SplineConv in the model, I get a RuntimeError: Cannot access data pointer of Tensor that doesn't have storage?

That's a bug on our side that can be fixed.

And are there any differences between using your example and using the one from the PyTorch docs?

If the one in the doc works for you, then go for it! I would recommend replacing torch.vmap with functorch.vmap -- functorch.vmap is an updated version that supports more operators.

>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
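
For reference, here's roughly what that pattern looks like written self-contained with functorch's vjp (a sketch using a toy function; f, x, and I_N below are placeholders, not your model):

import torch
from functorch import vmap, vjp

def f(x):
    return x.sin() + x ** 2

x = torch.randn(5)

# vjp returns f(x) plus a function that computes v^T J for a cotangent v
y, vjp_fn = vjp(f, x)

# Batch the vjp over the rows of the identity matrix to recover the full Jacobian
I_N = torch.eye(y.numel())
jacobian = vmap(vjp_fn)(I_N)[0]    # shape [5, 5]
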
kfxw commented 2 years ago

Sounds great! It would be perfect if you could fix this bug; I'm looking forward to it. I would be grateful if you could mention this issue again or notify me when the bug is fixed. Thanks!

zou3519 commented 2 years ago

Let's keep the issue open to track this.