f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Gradient of tensor output #266

Closed fabian-sp closed 1 year ago

fabian-sp commented 1 year ago

Heya,

I have a scenario where I have a multidimensional model output (let's say of dimension m), so what goes into the final loss function has shape [b, m] (b being the batch size). Is there a way to compute the gradients of every single element of that tensor (let's call it U) with one backward pass?

Let's say m=1. Then I got this to work by just computing

a = U.sum()

with backpack(BatchGrad()):
    a.backward()

For m>1, my idea was to first reshape U into [b*m, 1] and then do the same trick, but somehow this does not work: the shape of grad_batch is still always [b, *grad.shape].

If you can help me with this, that would be awesome :)

Cheers, Fabian

A simple code example:

import torch
from backpack import extend
from backpack import backpack
from backpack.extensions import BatchGrad

m = 3
p = 6
b = 5

X = torch.randn(b, p)
W = torch.nn.Linear(p, m)

W = extend(W)
U = W(X)
l = U.sum()

with backpack(BatchGrad()):
    l.backward()
f-dangel commented 1 year ago

Hi, thanks for your question.

Let me try to rephrase it to see if I got what you mean. You want to compute the derivatives ∂U / ∂W (a Jacobian matrix), where U is a batched vector (not scalar) and W is the parameter of some layer.


Here is a comparison between 1) autograd (m * b for loops), 2) BackPACK's BatchGrad (m for loops), and 3) BackPACK's core functionality (no for loops). It computes the quantity you requested with each approach and checks for equal results.

In my opinion, you are best off choosing between options 1) and 2). Option 3) requires writing a second-order extension because you need to pass back separate quantities (Jacobians, and not vector-Jacobian products as done by PyTorch) through the graph.

import itertools

import torch

from backpack import backpack, extend
from backpack.core.derivatives.linear import LinearDerivatives
from backpack.extensions import BatchGrad

m = 3
p = 6
b = 5

# Using autograd (for loop over batch and output dimension)
torch.manual_seed(0)
X = torch.randn(b, p)

linear = torch.nn.Linear(p, m, bias=False)
U = linear(X)

autograd_dU_dW = torch.zeros((b, m) + linear.weight.shape)

for b_idx, m_idx in itertools.product(range(b), range(m)):
    loss = U[b_idx, m_idx]
    (autograd_dU_dW[b_idx, m_idx],) = torch.autograd.grad(
        loss, linear.weight, retain_graph=True
    )

# Using BackPACK's BatchGrad (for loop over output dimension)
torch.manual_seed(0)
X = torch.randn(b, p)

linear = extend(torch.nn.Linear(p, m, bias=False))
U = linear(X)
losses = U.sum(0)

BatchGrad_dU_dW = torch.zeros((b, m) + linear.weight.shape)

for m_idx in range(m):
    loss = losses[m_idx]
    with backpack(BatchGrad(), retain_graph=True):
        loss.backward(retain_graph=True)
    BatchGrad_dU_dW[:, m_idx] = linear.weight.grad_batch

assert torch.allclose(autograd_dU_dW, BatchGrad_dU_dW)

# Using BackPACK's core (no for loops, but you would have to write a new extension)
# NOTE Not recommended, only read this if you know what you are doing.
torch.manual_seed(0)
X = torch.randn(b, p)

linear = extend(torch.nn.Linear(p, m, bias=False))
U = linear(X)

identity = torch.stack(b * [torch.eye(m, m)]).transpose(0, 1)
backpack_core_dU_dW = (
    LinearDerivatives()
    .param_mjp("weight", linear, None, None, identity, sum_batch=False)
    .reshape(m, b, *linear.weight.shape)
    .transpose(0, 1)
)

assert torch.allclose(autograd_dU_dW, backpack_core_dU_dW)

Best, Felix

fKunstner commented 1 year ago

Hi Fabian,

BackPACK doesn't support batched Jacobians. The trick doesn't work because BatchGrad() uses the first dimension of whatever is passed through the network as the batch dimension.
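
To make that concrete, here is a minimal self-contained sketch (same shapes as your snippet) showing that summing a reshaped [b*m, 1] view of U still leaves grad_batch with a leading dimension of b, since the reshape happens outside the extended layer:

import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad

m = 3
p = 6
b = 5

X = torch.randn(b, p)
W = extend(torch.nn.Linear(p, m))

U = W(X)
# Reshape to [b * m, 1] after the extended layer, as in the trick above.
# Inside the layer, BackPACK still sees b samples.
a = U.reshape(b * m, 1).sum()

with backpack(BatchGrad()):
    a.backward()

print(W.weight.grad_batch.shape)  # torch.Size([5, 3, 6]): leading dim is b, not b * m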

There might be a way to bodge something together, or there might be better options depending on your end goal. Can you give us some details on your setup? What kind of models are you interested in (what layers would go in there), and what do you want to do with those batch Jacobians (NTK stuff)?

fabian-sp commented 1 year ago

Hi, thank you for the quick response and your help! I need the individual elements for exploring an optimization method, so I have no particular model or layer in mind.

In a dream world, I would be able to compute those quantities roughly in the same time as a single .backward() step.

I guess solution 2) would be good enough for models with a small output dimension. However, I will check out option 3) as well and try to understand how it would work. @f-dangel Is there any particular reason why you actively recommend against writing this extension? Do you think it would be too hard to do? Could you explain how your hacky solution works, or point me to the right place in your docs to understand it?

Thanks again and cheers, Fabian

f-dangel commented 1 year ago

Hi Fabian,

I mainly did not recommend implementing this extension because of the amount of work.

After some thinking, I found another trick to compute the Jacobian you requested, re-using one of BackPACK's existing extensions. It relies on the idea that the matrix square root of the generalized Gauss-Newton/Fisher matrix for MSELoss is, up to scaling, the Jacobian (more details between Eqs. (2) and (3)). Here is how it works:

You need to feed U through an MSELoss(reduction='sum') layer with arbitrary labels, then call backward() on the result while using BackPACK's SqrtGGNExact extension. Each parameter W will then have a W.sqrt_ggn_exact attribute that contains the Jacobian ∂U / ∂W, up to a transposition and a scaling by 1 / sqrt(2).
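
Roughly, here is where the transposition and the scaling come from (a short sketch of the reasoning, using the conventions above):

ℓ(U) = Σ_{n,i} (U_{n,i} − y_{n,i})²   ⇒   ∇²_U ℓ = 2 I
G(W) = Jᵀ (∇²_U ℓ) J = (√2 J)ᵀ (√2 J),   where J = ∂U / ∂W

So the symmetric factorization stored in sqrt_ggn_exact is √2 J with the output dimension leading, and transposing plus dividing by √2 recovers the Jacobian.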

Here is the described approach in code. You have to append it to the snippet above to be able to run it (the final check re-uses autograd_dU_dW from there):

from math import sqrt
from backpack.extensions import SqrtGGNExact

# Using BackPACK's SqrtGGNExact extension
torch.manual_seed(0)
X = torch.randn(b, p)

linear = extend(torch.nn.Linear(p, m, bias=False))
loss_fn = extend(torch.nn.MSELoss(reduction="sum"))

U = linear(X)
U_labels = torch.zeros_like(U)  # can contain arbitrary values
loss = loss_fn(U, U_labels)

with backpack(SqrtGGNExact()):
    loss.backward()

backpack_SqrtGGNExact_dU_dW = linear.weight.sqrt_ggn_exact.transpose(0, 1) / sqrt(2)

assert torch.allclose(autograd_dU_dW, backpack_SqrtGGNExact_dU_dW)

Let me know if this is helpful, Felix


Btw, this also works if U has a more complicated shape, e.g. [b, m, n]. In this case, you need to add a Flatten layer before the MSELoss.
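
For concreteness, here is a minimal sketch of that case. The Conv1d layer and the shapes are just made up for illustration, and it assumes a BackPACK version whose SqrtGGNExact supports Conv1d and Flatten:

import torch
from math import sqrt

from backpack import backpack, extend
from backpack.extensions import SqrtGGNExact

b, c, length, m = 5, 2, 8, 3

torch.manual_seed(0)
X = torch.randn(b, c, length)

conv = extend(torch.nn.Conv1d(c, m, kernel_size=3, bias=False))
flatten = extend(torch.nn.Flatten())
loss_fn = extend(torch.nn.MSELoss(reduction="sum"))

U = conv(X)          # shape [b, m, n] with n = length - 2 = 6
U_flat = flatten(U)  # shape [b, m * n]
loss = loss_fn(U_flat, torch.zeros_like(U_flat))  # labels can be arbitrary

with backpack(SqrtGGNExact()):
    loss.backward()

# Jacobian of the flattened output w.r.t. the conv weight,
# shape [b, m * n, *conv.weight.shape]
dU_dW = conv.weight.sqrt_ggn_exact.transpose(0, 1) / sqrt(2)
print(dU_dW.shape)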

f-dangel commented 1 year ago

I added an example to the documentation.

fabian-sp commented 1 year ago

Hello Felix,

great, thank you for your effort!