Closed JoaoLages closed 1 year ago
Hi @JoaoLages, if you want to compute multiple losses from a single sample you should be able to do this via forward-mode AD. I've written a brief example below,
import torch
from torch import nn
import functorch
from functorch import jacfwd, make_functional, vmap

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.fc = nn.Linear(2, 4)

    def forward(self, x):
        y = self.fc(x)
        return y, y.pow(2).sum(-1).sqrt()  # our two losses (output, and Euclidean norm)

B = 100  # batch size
N = 2  # input size
x = torch.randn(B, N)
func = model()
fnet, params = make_functional(func)
dy_dx, dnorm_dx = vmap(jacfwd(fnet, argnums=1), in_dims=(None, 0))(params, x)
print(dy_dx.shape, dnorm_dx.shape)  # returns torch.Size([100, 4, 2]) torch.Size([100, 2])
As can be seen from the example, both the Jacobian of the nn.Linear layer and the gradient of its norm are computed in parallel, and it's all vectorized via vmap!
(For completeness this issue branches from a forum post, found here)
In my case the model is very heavy and I can only call model.forward with a single sample.
I would like to first compute Y = model.forward(X) and then compute the gradient of the multiple scalar losses that are in Y, do you see?
I don't see why that'd be an issue. If you define a function that evaluates your model, computes all the scalar losses, and returns them as outputs, you should be able to re-use the example I wrote above (without using vmap, of course).
Also, try and attach a minimal reproducible example so people can debug the problem!
Imagine a text-to-image model, which is exactly my use case.
import torch
from torch import nn
import functorch
from functorch import jacfwd, make_functional, vmap
text = "a photo of an astronaut riding a horse on mars"
X = get_input_embeddings(text) # output [1, embedding_dim]
Y = text2imagemodel(X) # shape [1, width, height]
How do I calculate the gradient of each pixel in Y with respect to the input embeddings X?
Have you tried using jacrev or jacfwd?
Something like,
fnet, params = make_functional(text2imagemodel) #or the buffers version if your model uses buffers
gradient = vmap(jacrev(fnet, argnums=0), in_dims=(None, 0))(params, X)
Thanks for your help! 🙏 Will try to have a look at it again tomorrow and will come back to this issue :)
jacrev ran the whole forward pass of the model, but then exploded in RAM (and I have 350GB of RAM 😬, and yes, I'm running on CPU).
When replacing jacrev with jacfwd, the following error occurs:
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
Even though I already had randomness='same' in my call:
gradient = vmap(jacrev(model_forward_func), randomness="same")(X.unsqueeze(0))
At the end of the day I just wanted torch.autograd.grad to return me the gradients per scalar and not their sum 😩
It should be argnums=1, not 0,
gradient = vmap(jacrev(fnet, argnums=1), in_dims=(None, 0))(params, X)
as argnums=0 differentiates with respect to the parameters and not the input. Try that and see if it works.
Does it make a difference to pass params to vmap? My model is composed of more than one nn.Module, so I don't pass the params to vmap; they are inside my model_forward_func function:
gradient = vmap(jacrev(model_forward_func), randomness="same")(X.unsqueeze(0))
Like this, I am not able to pass argnums=1 to jacrev.
When replacing jacrev with jacfwd, the following error occurs ...
Ahh sorry that's my fault, I'll put up a patch to fix that today. In the meantime, if you want to use fwd mode, could you try this (basically recreating what jacfwd is doing under the hood):
import torch
from functorch import jvp, vmap

def push_jvp(primal, basis):
    # torch.sum is the function being differentiated here; presumably you would swap in your own function.
    # This assumes we're only passing in a single tensor X. I can forward some code if you need to pass a tuple
    _, jvp_out = jvp(torch.sum, (primal,), (basis,))
    return jvp_out

def jacfwd_hack(primal):
    tangents = torch.eye(primal.numel()).reshape(-1, *primal.shape)
    out = vmap(push_jvp, randomness="same", in_dims=(None, 0))(primal, tangents)  # this is the vmap that was causing issues earlier
    return out.reshape(*primal.shape, *out.shape[1:])

vmap(jacfwd_hack, randomness="same")(X)
jacrev ran the whole forward pass of the model, but then exploded in RAM (and I am with 350GB of RAM 😬, and yes I'm running on CPU)
If you do still want to run jacrev instead of jacfwd, we can try this:
gradient = functorch.experimental.chunk_vmap(jacrev(model_forward_func), chunks=2, randomness="same")(X)
and then play with the chunks parameter. Basically, this parameter manually offers a time/memory tradeoff. Here it will chunk X into 2 separate tensors, then run each of those through vmap separately and stack the results back together. It will be slower but will use less peak memory.
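To make the tradeoff concrete, here is a rough sketch of what chunking over the batch dimension amounts to. This is not functorch's implementation: `apply_in_chunks` and `batched_fn` are hypothetical names, and `batched_fn` merely stands in for something like `vmap(jacrev(model_forward_func))`.

```python
import torch

def apply_in_chunks(fn, x, chunks):
    """Run a batched function on slices of x and re-join the results."""
    pieces = torch.chunk(x, chunks, dim=0)  # split along the batch dimension
    # Only one piece is processed at a time, so peak memory is lower.
    return torch.cat([fn(p) for p in pieces], dim=0)

# Hypothetical stand-in for an expensive batched transform.
batched_fn = lambda t: t.pow(2).sum(-1)

X = torch.randn(6, 3)
out = apply_in_chunks(batched_fn, X, chunks=2)
print(out.shape)  # torch.Size([6])
```

Note that a batch of size 1 yields a single piece, so in that case chunking along the batch saves nothing.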
Does it make a difference to pass params to vmap?
In this case, no. More generally, if you ever want the gradients with respect to the parameters of the nn.Module, you'll have to pass them in, so a lot of the time people get in the habit of always passing the parameters to the function.
functorch.experimental.chunk_vmap does not exist in the current version, functorch==0.2.1. Nonetheless, I imported all the code from this file to try this chunk_vmap 👀
Unfortunately it is still not working. The error is always the same and does not change if I set a larger number of chunks:
RuntimeError: [enforce fail at alloc_cpu.cpp:66] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 2473901162496 bytes. Error code 12 (Cannot allocate memory)
Edit: If it helps, the code is breaking in here:
File ~/diffusers-interpret/venv/lib/python3.10/site-packages/functorch/_src/eager_transforms.py:555, in <genexpr>(.0)
553 total_numel = sum(tensor_numels)
554 diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
--> 555 chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
556 for tensor, tensor_numel in zip(tensors, tensor_numels))
557 for chunk, diag_start_idx in zip(chunks, diag_start_indices):
558 chunk.diagonal(diag_start_idx).fill_(1)
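The allocation size in the error appears consistent with the `new_zeros(total_numel, tensor_numel)` call in the traceback building a square basis over all output elements before any backward pass runs. A quick check, assuming a float32 output with 3 × 512 × 512 elements (the channel count is inferred from the byte count, not stated in the thread):

```python
# Assumption: Y has 3 * 512 * 512 elements (3 channels is a guess) and is float32.
out_numel = 3 * 512 * 512
bytes_per_float32 = 4

# A square (out_numel x out_numel) basis of one-hot vectors.
basis_bytes = out_numel * out_numel * bytes_per_float32

print(basis_bytes)          # 2473901162496, the figure in the error message
print(basis_bytes / 2**30)  # 2304.0 (GiB)
```

If that reading is right, the basis size depends only on the output size, which would explain why changing the number of chunks over the input batch made no difference.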
Perhaps it might be better to upgrade to the latest version?
Also, doesn't chunk_vmap only work for chunking over a batch of samples? So, if you're working on a single sample, chunk_vmap won't change anything, will it? (Unless I'm missing something)
Perhaps it might be better to upgrade to the latest version?
I'm in the latest PyPi version :)
Also, doesn't chunk_vmap only work for chunking over a batch of samples? So, if you're working on a single sample, chunk_vmap won't change anything, will it? (Unless I'm missing something)
Yes, that's also what I understood :(
model_forward_func(X) works fine, can't I reuse this computation for jacrev?
Yes, but the latest version is 0.3.0, which is built from source.
How many parameters are in your model? Because if it can't compute the gradients with a single sample it might just be an OOM issue.
It's pretty huge, summing up all the parameters I get a total of 1,066,235,307 parameters.
Edit: This is the model btw
Is it possible to calculate gradients in the middle of the forward pass and aggregate at the end? 🤔
To check @AlphaBetaGamma96's intuition that it might just be an OOM issue: I know you're able to compute the forward pass, but are you able to compute just gradients on your machine (grad(model_forward_func))? If so, and you're okay with it being slow, we might be able to chunk it from there (right now jacrev doesn't have a chunks parameter, but we were also considering that).
It does not allow me:
RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a scalar Tensor, got tensor with 4 dims. Maybe you wanted to use the vjp or jacrev APIs instead?
However, running torch.autograd.grad works fine, although it requires ~150GB RAM and takes ~1min for this model.
text = "a photo of an astronaut riding a horse on mars"
X = get_input_embeddings(text) # output [1, embedding_dim]
Y = text2imagemodel(X) # shape [1, width, height]
You're trying to compute the derivative of the entire output w.r.t. the entire input, right? From your comment above you've basically got a Jacobian of [width, height, embedding_dim]. What is the size of these three values?
So grad (from functorch) won't work as that expects a scalar, right? (hence the error message suggesting jacrev) But for torch.autograd.grad you must be passing an argument like torch.autograd.grad(output, input, torch.ones_like(output)) in order to get it to work for non-scalar outputs? Can you share how you got torch.autograd.grad to work?
You're trying to compute the derivative of the entire output w.r.t to the entire input right?
Exactly, the gradient of each pixel with respect to each word embedding
From your comment above you've basically got a Jacobian of [width, height, embedding_dim]. What is the size of these three values?
[512, 512, 768]
So grad (from functorch) won't work as that expected a scalar, right? (hence the error message suggesting jacrev) But for torch.autograd.grad you must be passing an argument like torch.autograd.grad(output, input, torch.ones_like(output)) in order to get it to work for non-scalar outputs? Can you share how you got torch.autograd.grad to work?
torch.autograd.grad accepts a tuple of tensors as its outputs, so I just need to split the output Y into a tuple of tensors with Y = tuple(torch.flatten(Y)).
The only problem is that torch.autograd.grad returns a single gradient with respect to the entire input, [batch, input_dim, embedding_dim], instead of a tuple of gradients, one for each scalar.
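A tiny example may help show the behaviour described here. Passing several scalar outputs to torch.autograd.grad returns the sum of their gradients. One way to get them separately, assuming one backward pass per scalar is affordable, is to call it once per output with retain_graph=True so the single forward pass is reused. The toy function below is purely illustrative.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2  # one forward pass, two scalar "losses"

# All scalars at once: the per-scalar gradients are summed.
(summed,) = torch.autograd.grad(tuple(y), x, retain_graph=True)

# One backward pass per scalar, reusing the same graph each time.
per_scalar = [torch.autograd.grad(y_i, x, retain_graph=True)[0] for y_i in y]

print(summed)      # tensor([2., 4.])
print(per_scalar)  # [tensor([2., 0.]), tensor([0., 4.])]
```

For an output with hundreds of thousands of scalars this loop would be very slow, so it is a way to see the semantics rather than a practical recipe.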
@AlphaBetaGamma96 @samdow thanks for your perseverance in this issue, really appreciate it 🙏 . It would be amazing if we get this to work!
@JoaoLages Sorry for the delay in response, do you have an E2E repro that you could share? We're trying to understand if it's going to be better to recommend using a variation of torch.autograd.grad or a composition of chunk_vmap and vjp.
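As one possible shape for the torch.autograd.grad variation mentioned here, the sketch below runs the forward pass once and then requests a few Jacobian rows per backward call through the `is_grads_batched` flag of torch.autograd.grad. It assumes a reasonably recent PyTorch where that flag is available; the flag relies on vmap internally, so operations without batching support may fail. `chunked_jacobian` and the toy function are hypothetical, not something proposed in the thread.

```python
import torch

def chunked_jacobian(y, x, chunk_size):
    """Jacobian of y w.r.t. x, computing `chunk_size` output rows per backward call."""
    y_flat = y.reshape(-1)
    n = y_flat.numel()
    blocks = []
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        rows = stop - start
        # One-hot cotangents for rows [start, stop) only, never the full n x n basis.
        basis = y_flat.new_zeros(rows, n)
        basis[:, start:stop] = torch.eye(rows, dtype=y.dtype, device=y.device)
        (g,) = torch.autograd.grad(
            y_flat, x, grad_outputs=basis,
            is_grads_batched=True, retain_graph=True,
        )
        blocks.append(g)  # shape [rows, *x.shape]
    return torch.cat(blocks).reshape(*y.shape, *x.shape)

# Toy stand-in for the real model (illustration only).
torch.manual_seed(0)
W = torch.randn(5, 3)
x = torch.randn(3, requires_grad=True)
y = torch.tanh((W * x).sum(dim=1))  # single forward pass, 5 outputs
J = chunked_jacobian(y, x, chunk_size=2)
print(J.shape)  # torch.Size([5, 3])
```

The cotangent basis held at any one time has chunk_size × n entries instead of n × n, at the cost of more backward calls.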
@samdow sorry for the late reply. I created a Google Colab to reproduce the error. This is exploding the GPU when it wasn't really expected for a 64x64 image... With normal torch.autograd.grad the gradient calculation is fast and does not use so much GPU memory.
According to the README, we are able to calculate per-sample-gradients with functorch.
But what if we want to get multiple gradients for a single sample? For example, imagine that we are calculating multiple losses.
We can split each loss calculation as a different sample, but that implementation is inefficient, especially when the forward pass is expensive. Can we at least re-use forward computations?