pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Multiple gradient calculation for single sample #1010

Closed JoaoLages closed 1 year ago

JoaoLages commented 1 year ago

According to the README, we are able to calculate per-sample-gradients with functorch.

But what if we want to get multiple gradients for a single sample? For example, imagine that we are calculating multiple losses.

We could treat each loss calculation as a separate sample, but that implementation is inefficient, especially when the forward pass is expensive. Can we at least reuse the forward computation?

AlphaBetaGamma96 commented 1 year ago

Hi @JoaoLages, if you want to compute multiple losses from a single sample you should be able to do this via forward-mode AD. I've written a brief example below,

import torch
from torch import nn

from functorch import jacfwd, make_functional, vmap

class Model(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc = nn.Linear(2, 4)

  def forward(self, x):
    y = self.fc(x)
    return y, y.pow(2).sum(-1).sqrt()  # our two "losses": the output and its Euclidean norm

B = 100  # batch size
N = 2    # input size

x = torch.randn(B, N)
func = Model()

fnet, params = make_functional(func)

dy_dx, dnorm_dx = vmap(jacfwd(fnet, argnums=1), in_dims=(None, 0))(params, x)
print(dy_dx.shape, dnorm_dx.shape)  # prints torch.Size([100, 4, 2]) torch.Size([100, 2])

As the example shows, both the Jacobian of the nn.Linear output and the gradient of its norm are computed in parallel from a single forward pass, and it's all vectorized over the batch via vmap!
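(A quick sanity check of the first sample against plain autograd, sketched with torch.autograd.functional.jacobian:)

from torch.autograd.functional import jacobian

# reference Jacobian of the first output for the first sample
j0 = jacobian(lambda inp: fnet(params, inp)[0], x[0])
print(torch.allclose(j0, dy_dx[0]))  # True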

(For completeness this issue branches from a forum post, found here)

JoaoLages commented 1 year ago

In my case the model is very heavy and I can only call model.forward with a single sample. I would like to first compute Y = model.forward(X) and then compute the gradient of the multiple scalar losses that are in Y, do you see?

AlphaBetaGamma96 commented 1 year ago

In my case the model is very heavy and I can only call model.forward with a single sample. I would like to first compute Y = model.forward(X) and then compute the gradient of the multiple scalar losses that are in Y, do you see?

I don't see why that'd be an issue: if you define a function that evaluates your model, computes all the scalar losses, and returns them as outputs, you should be able to reuse the example I wrote above (without the vmap, of course).
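For instance, a minimal sketch of that pattern (the tiny net and the two losses here are placeholders, not the heavy model from the question):

import torch
from torch import nn
from functorch import jacfwd, make_functional

net = nn.Linear(2, 4)               # stand-in for the real model
fnet, params = make_functional(net)

def losses(params, x):
    y = fnet(params, x)
    # several scalar losses computed from a single forward pass
    return y.sum(), y.pow(2).sum()

x = torch.randn(2)                  # one sample, no batch dim, no vmap
dl1_dx, dl2_dx = jacfwd(losses, argnums=1)(params, x)
print(dl1_dx.shape, dl2_dx.shape)   # torch.Size([2]) torch.Size([2])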

Also, try and attach a minimal reproducible example so people can debug the problem!

JoaoLages commented 1 year ago

Imagine a text-to-image model, which is exactly my use case.

import torch
from torch import nn

import functorch
from functorch import jacfwd, make_functional, vmap

text = "a photo of an astronaut riding a horse on mars"
X = get_input_embeddings(text)  # shape [1, embedding_dim]

Y = text2imagemodel(X) # shape [1, width, height]

How do I calculate the gradient of each pixel in Y with respect to the input embeddings X?

AlphaBetaGamma96 commented 1 year ago

Have you tried using jacrev or jacfwd?

Something like,

fnet, params = make_functional(text2imagemodel) #or the buffers version if your model uses buffers
gradient = vmap(jacrev(fnet, argnums=0), in_dims=(None, 0))(params, X)

JoaoLages commented 1 year ago

Have you tried using jacrev or jacfwd?

Something like,

fnet, params = make_functional(text2imagemodel) #or the buffers version if your model uses buffers
gradient = vmap(jacrev(fnet, argnums=0), in_dims=(None, 0))(params, X)

Thanks for your help! 🙏 Will try to have a look at it again tomorrow and will come back to this issue :)

JoaoLages commented 1 year ago

Have you tried using jacrev or jacfwd?

Something like,

fnet, params = make_functional(text2imagemodel) #or the buffers version if your model uses buffers
gradient = vmap(jacrev(fnet, argnums=0), in_dims=(None, 0))(params, X)

jacrev ran the whole forward pass of the model, but then exploded in RAM (and I have 350GB of RAM 😬, and yes, I'm running on CPU). When replacing jacrev with jacfwd, the following error occurs:

RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap

Even though I already had randomness='same' in my call: gradient = vmap(jacrev(model_forward_func), randomness="same")(X.unsqueeze(0))

At the end of the day I just wanted torch.autograd.grad to return me the gradients per scalar and not their sum 😩
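(To illustrate with a naive sketch, using the same placeholder names: the loop below returns one gradient per output scalar, but it replays the backward pass once per scalar, so it is slow for large outputs.)

X.requires_grad_(True)
Y = model_forward_func(X)  # forward pass runs once
# retain_graph keeps the saved activations alive between the grad calls
per_scalar_grads = [torch.autograd.grad(y_i, X, retain_graph=True)[0]
                    for y_i in Y.flatten()]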

AlphaBetaGamma96 commented 1 year ago

Have you tried using jacrev or jacfwd? Something like,

fnet, params = make_functional(text2imagemodel) #or the buffers version if your model uses buffers
gradient = vmap(jacrev(fnet, argnums=0), in_dims=(None, 0))(params, X)

jacrev ran the whole forward pass of the model, but then exploded in RAM (and I have 350GB of RAM 😬, and yes, I'm running on CPU). When replacing jacrev with jacfwd, the following error occurs:

RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap

Even though I already had randomness='same' in my call: gradient = vmap(jacrev(model_forward_func), randomness="same")(X.unsqueeze(0))

At the end of the day I just wanted torch.autograd.grad to return me the gradients per scalar and not their sum 😩

It should be argnums=1, not 0,

gradient = vmap(jacrev(fnet, argnums=1), in_dims=(None, 0))(params, X)

as fnet takes (params, X): argnums=0 is the parameters and argnums=1 is the input. Try that and see if it works.

JoaoLages commented 1 year ago

It should be argnums=1, not 0

Does it make a difference to pass params to vmap? My model is composed of more than one nn.Module, so I don't pass the params to vmap; they are captured inside my model_forward_func function:

gradient = vmap(jacrev(model_forward_func), randomness="same")(X.unsqueeze(0))

Written like this, I am not able to pass argnums=1 to jacrev.

samdow commented 1 year ago

When replacing jacrev with jacfwd, the following error occurs ...

Ahh sorry, that's my fault; I'll put up a patch to fix that today. In the meantime, if you want to use fwd mode, could you try this (basically recreating what jacfwd does under the hood)?

import torch
from functorch import jvp, vmap

def push_jvp(primal, basis):
  # this assumes we're only passing in a single tensor X; I can forward
  # some code if you need to pass a tuple
  _, jvp_out = jvp(model_forward_func, (primal,), (basis,))
  return jvp_out

def jacfwd_hack(primal):
  tangents = torch.eye(primal.numel()).reshape(-1, *primal.shape)
  out = vmap(push_jvp, randomness="same", in_dims=(None, 0))(primal, tangents)  # this is the vmap that was causing issues earlier
  return out.reshape(*primal.shape, *out.shape[1:])

vmap(jacfwd_hack, randomness="same")(X)

jacrev ran the whole forward pass of the model, but then exploded in RAM (and I have 350GB of RAM 😬, and yes, I'm running on CPU)

If you do still want to run jacrev instead of jacfwd, we can try this: gradient = functorch.experimental.chunk_vmap(jacrev(model_forward_func), chunks=2, randomness="same")(X), and then play with the chunks parameter. Basically, this parameter manually offers a time/memory tradeoff: it will chunk X into 2 separate tensors, run each of those through vmap separately, and stack the results back together. It will be slower but will use less peak memory.
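(Roughly, a hand-rolled sketch of what chunk_vmap(f, chunks=2) computes:)

import torch
from functorch import jacrev, vmap

f = jacrev(model_forward_func)
# run each chunk of the batch through vmap separately, then stack
parts = [vmap(f, randomness="same")(c) for c in X.chunk(2)]
gradient = torch.cat(parts)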

Does it make a difference to pass params to vmap?

In this case, no. More generally, if you ever want gradients with respect to the nn.Module's parameters you'll have to pass them in, so a lot of people get in the habit of always passing the parameters to the function.

JoaoLages commented 1 year ago

If you do still want to run jacrev instead of jacfwd, we can try this: gradient = functorch.experimental.chunk_vmap(jacrev(model_forward_func), chunks=2, randomness="same")(X), and then play with the chunks parameter. Basically, this parameter manually offers a time/memory tradeoff: it will chunk X into 2 separate tensors, run each of those through vmap separately, and stack the results back together. It will be slower but will use less peak memory.

functorch.experimental.chunk_vmap does not exist in the current version, functorch==0.2.1. Nonetheless, I imported all the code from this file to try chunk_vmap 👀

Unfortunately it is still not working. The error is always the same and does not change if I set a larger number of chunks:

RuntimeError: [enforce fail at alloc_cpu.cpp:66] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 2473901162496 bytes. Error code 12 (Cannot allocate memory)

Edit: if it helps, the code is breaking here, while building the (total_numel, tensor_numel) one-hot basis matrices:

File ~/diffusers-interpret/venv/lib/python3.10/site-packages/functorch/_src/eager_transforms.py:555, in <genexpr>(.0)
    553 total_numel = sum(tensor_numels)
    554 diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
--> 555 chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
    556                for tensor, tensor_numel in zip(tensors, tensor_numels))
    557 for chunk, diag_start_idx in zip(chunks, diag_start_indices):
    558     chunk.diagonal(diag_start_idx).fill_(1)

AlphaBetaGamma96 commented 1 year ago

Perhaps it might be better to upgrade to the latest version?

Also, doesn't chunk_vmap only work for chunking over a batch of samples? So if you're working on a single sample, chunk_vmap won't change anything, will it? (Unless I'm missing something.)

JoaoLages commented 1 year ago

Perhaps it might be better to upgrade to the latest version?

I'm on the latest PyPI version :)

Also, doesn't chunk_vmap only work for chunking over a batch of samples? So if you're working on a single sample, chunk_vmap won't change anything, will it? (Unless I'm missing something.)

Yes, that's also what I understood :(

JoaoLages commented 1 year ago

model_forward_func(X) works fine; can't I reuse this computation for jacrev?

AlphaBetaGamma96 commented 1 year ago

Perhaps it might be better to upgrade to the latest version?

I'm on the latest PyPI version :)

Also, doesn't chunk_vmap only work for chunking over a batch of samples? So if you're working on a single sample, chunk_vmap won't change anything, will it? (Unless I'm missing something.)

Yes, that's also what I understood :(

Yes, but the latest version, which is built from source, is 0.3.0.

How many parameters are in your model? Because if it can't compute the gradients for a single sample, it might just be an OOM issue.

JoaoLages commented 1 year ago

How many parameters are in your model? Because if it can't compute the gradients for a single sample, it might just be an OOM issue.

It's pretty huge: summing across all the modules I get a total of 1,066,235,307 parameters.

Edit: This is the model btw

JoaoLages commented 1 year ago

Is it possible to calculate gradients in the middle of the forward pass and aggregate at the end? 🤔

samdow commented 1 year ago

To check @AlphaBetaGamma96's intuition that it might just be an OOM issue: I know you're able to compute the forward pass, but are you able to compute just the gradients on your machine (grad(model_forward_func))? If so, and you're okay with it being slow, we might be able to chunk it from there (right now jacrev doesn't have a chunks parameter, but we were also considering that).
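(To be concrete, since grad needs a scalar-valued function, such a probe might look like the sketch below; the .sum() is purely a stand-in reduction to get a scalar.)

from functorch import grad

# probes whether a single backward pass fits in memory
probe_grad = grad(lambda x: model_forward_func(x).sum())(X)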

JoaoLages commented 1 year ago

To check @AlphaBetaGamma96's intuition that it might just be an OOM issue: I know you're able to compute the forward pass, but are you able to compute just the gradients on your machine (grad(model_forward_func))? If so, and you're okay with it being slow, we might be able to chunk it from there (right now jacrev doesn't have a chunks parameter, but we were also considering that).

It does not let me: RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a scalar Tensor, got tensor with 4 dims. Maybe you wanted to use the vjp or jacrev APIs instead?

However, running torch.autograd.grad works fine, although it requires ~150GB RAM and takes ~1min for this model.

AlphaBetaGamma96 commented 1 year ago

text = "a photo of an astronaut riding a horse on mars" X = get_input_embeddings(text) # output [1, embedding_dim]

Y = text2imagemodel(X) # shape [1, width, height]

You're trying to compute the derivative of the entire output w.r.t. the entire input, right? From your comment above you've basically got a Jacobian of [width, height, embedding_dim]. What is the size of these three values?

So grad (from functorch) won't work, as it expects a scalar, right? (Hence the error message suggesting jacrev.) But for torch.autograd.grad you must be passing an argument like torch.autograd.grad(output, input, torch.ones_like(output)) in order to get it to work for non-scalar outputs? Can you share how you got torch.autograd.grad to work?

JoaoLages commented 1 year ago

You're trying to compute the derivative of the entire output w.r.t. the entire input, right?

Exactly, the gradient of each pixel with respect to each word embedding

From your comment above you've basically got a Jacobian of [width, height, embedding_dim]. What is the size of these three values?

[512, 512, 768]

So grad (from functorch) won't work, as it expects a scalar, right? (Hence the error message suggesting jacrev.) But for torch.autograd.grad you must be passing an argument like torch.autograd.grad(output, input, torch.ones_like(output)) in order to get it to work for non-scalar outputs? Can you share how you got torch.autograd.grad to work?

torch.autograd.grad accepts a tuple of tensors as outputs; I just need to split the output Y into a tuple of scalar tensors with Y = tuple(torch.flatten(Y)). The only problem is that torch.autograd.grad returns a single gradient with respect to the entire input, of shape [batch, input_dim, embedding_dim], instead of a tuple of gradients, one for each scalar.
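(For what it's worth, a sketch of that per-scalar behaviour, assuming a PyTorch version >= 1.11 where torch.autograd.grad gained an is_grads_batched flag; note the flag uses vmap internally, so it can hit the same memory wall:)

import torch

X.requires_grad_(True)
Y = model_forward_func(X)                              # shape [1, width, height]
v = torch.eye(Y.numel()).reshape(Y.numel(), *Y.shape)  # one one-hot cotangent per pixel
(grads,) = torch.autograd.grad(Y, X, grad_outputs=v, is_grads_batched=True)
# grads: [width*height, *X.shape], one gradient per output scalar rather than their sum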

@AlphaBetaGamma96 @samdow thanks for your perseverance in this issue, really appreciate it 🙏 . It would be amazing if we get this to work!

samdow commented 1 year ago

@JoaoLages Sorry for the delay in response, do you have an E2E repro that you could share? We're trying to understand whether it's going to be better to recommend a variation of torch.autograd.grad or a composition of chunk_vmap and vjp.
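(A rough sketch of that composition, with a hypothetical chunk size of 4096: the forward pass runs exactly once and only the pullback is repeated per chunk.)

import torch
from functorch import vjp, vmap

out, pullback = vjp(model_forward_func, X)  # forward pass runs once
basis = torch.eye(out.numel()).reshape(-1, *out.shape)
# pull back one-hot cotangents in chunks to cap peak memory
# (for very large outputs the eye basis itself should also be built per chunk)
rows = [vmap(pullback)(chunk)[0] for chunk in basis.split(4096)]
jacobian = torch.cat(rows).reshape(*out.shape, *X.shape)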

JoaoLages commented 1 year ago

@JoaoLages Sorry for the delay in response, do you have an E2E repro that you could share? We're trying to understand whether it's going to be better to recommend a variation of torch.autograd.grad or a composition of chunk_vmap and vjp.

@samdow sorry for the late reply. I created a Google Colab to reproduce the error. It is exhausting GPU memory when that wasn't really expected for a 64x64 image... With plain torch.autograd.grad the gradient calculation is fast and does not use nearly as much GPU memory.