f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Optimizing the locations of the Jacobians #300

Open Ludvins opened 1 year ago

Ludvins commented 1 year ago

Hi!

Thanks for the contributions! I am facing a weird situation in which I am computing the Jacobians of a model w.r.t. the parameters (which are fixed and cannot be tuned), and I want to optimize the input locations of those Jacobians, let's say the `x` variables. Two different things happen when I try to do this:

  1. The locations retain a `.grad` value when using the backward call to compute the Jacobians w.r.t. the parameters using BackPACK.
  2. The backward call of my loss function does not alter the `.grad` value of the inputs used to compute the Jacobians.

Is there any way I can fix this?

f-dangel commented 1 year ago

Hi @Ludvins,

do you have a small code snippet that explains what you are trying to do? This would be valuable to figure out a solution.

Best, Felix

Ludvins commented 1 year ago

Sure, this is a function from the Laplace library that computes the Jacobian of a model w.r.t. the parameters given an input `x`:

```python
import torch
from backpack import backpack, extend
from backpack.context import CTX
from backpack.extensions import BatchGrad

# Method from the Laplace library: `self.model` and `self.subnetwork_indices`
# are attributes of its class, and `_cleanup` is a Laplace helper not shown here.
def jacobians(self, x):
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using backpack's BatchGrad per output dimension.
    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.
    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    model = extend(self.model)
    to_stack = []
    for i in range(model.output_size):
        model.zero_grad()
        out = model(x)
        with backpack(BatchGrad()):
            # Backpropagate output dimension i to get its per-sample gradients
            if model.output_size > 1:
                out[:, i].sum().backward()
            else:
                out.sum().backward()
            to_cat = []
            for param in model.parameters():
                # detach() cuts the graph, so Jk cannot be differentiated wrt x
                to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1))
                delattr(param, 'grad_batch')
            Jk = torch.cat(to_cat, dim=1)
            if self.subnetwork_indices is not None:
                Jk = Jk[:, self.subnetwork_indices]
        to_stack.append(Jk)
        if i == 0:
            f = out.detach()

    model.zero_grad()
    CTX.remove_hooks()
    _cleanup(model)
    # The transpose moves outputs before parameters: `(batch, outputs, parameters)`
    if model.output_size > 1:
        return torch.stack(to_stack, dim=2).transpose(1, 2), f
    else:
        return Jk.unsqueeze(-1).transpose(1, 2), f
```

I am using the output of this function evaluated on a set of trainable locations x to optimize x. However, the above piece of code has two problems:

  1. The backward pass creates a `.grad` on `x` that is not desired.
  2. `Jk` does not require grad (it is built from `param.grad_batch.detach()`), so optimizing `x` is not possible.

I have made a workaround using `create_graph=True` in the backward call of the `jacobians` function above. However, this is probably not a good workaround, as I am getting the following warning: `UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak.` I am leaving the edited code below to highlight my changes:

```python
import torch
from backpack import backpack, extend
from backpack.context import CTX
from backpack.extensions import BatchGrad

# Same Laplace method as above; `_cleanup` is a Laplace helper not shown here.
def jacobians(self, x):
    """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
    using backpack's BatchGrad per output dimension.

    Parameters
    ----------
    x : torch.Tensor
        input data `(batch, input_shape)` on compatible device with model.

    Returns
    -------
    Js : torch.Tensor
        Jacobians `(batch, outputs, parameters)`
    f : torch.Tensor
        output function `(batch, outputs)`
    """
    model = extend(self.model)
    to_stack = []
    for i in range(model.output_size):
        model.zero_grad()
        out = model(x)
        # retain_graph/create_graph keep grad_batch differentiable wrt x
        with backpack(BatchGrad(), retain_graph=True):
            if model.output_size > 1:
                out[:, i].sum().backward(create_graph=True)
            else:
                out.sum().backward(create_graph=True)
            to_cat = []
            for param in model.parameters():
                # No .detach(): Jk stays in the autograd graph
                to_cat.append(param.grad_batch.reshape(x.shape[0], -1))
                delattr(param, 'grad_batch')
            Jk = torch.cat(to_cat, dim=1)
            if self.subnetwork_indices is not None:
                Jk = Jk[:, self.subnetwork_indices]
        to_stack.append(Jk)
        if i == 0:
            f = out  # not detached either

    model.zero_grad()
    x.grad = None  # discard the unwanted gradient accumulated into x
    CTX.remove_hooks()
    _cleanup(model)
    if model.output_size > 1:
        return torch.stack(to_stack, dim=2).transpose(1, 2), f
    else:
        return Jk.unsqueeze(-1).transpose(1, 2), f
```

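
Why `create_graph=True` is what makes `Jk` differentiable can be seen in plain PyTorch, without BackPACK. A minimal sketch (the small `nn.Linear` model and the squared-norm loss are illustrative choices, not from the thread); it uses `torch.autograd.grad` rather than `.backward()`, so no `.grad` fields are written:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x = torch.randn(4, 3, requires_grad=True)

# Without create_graph, the parameter gradient is a plain tensor: it carries
# no graph back to x, so nothing built from it can be differentiated wrt x.
(g_plain,) = torch.autograd.grad(model(x).sum(), model.weight)
print(g_plain.requires_grad)  # False

# With create_graph=True, the gradient is itself part of the autograd graph.
# For nn.Linear, d(sum_i f(x_i))/dW = sum_i x_i, which depends on x.
(g_graph,) = torch.autograd.grad(model(x).sum(), model.weight, create_graph=True)
print(g_graph.requires_grad)  # True

# A loss built from that gradient can now be differentiated wrt x.
loss = g_graph.pow(2).sum()
(dloss_dx,) = torch.autograd.grad(loss, x)
print(dloss_dx.shape)  # torch.Size([4, 3])
```

The `UserWarning` quoted above is specific to `backward()` with `create_graph=True`, because that call stores differentiable gradients in the parameters' `.grad` fields.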
fKunstner commented 1 year ago

Hi @Ludvins,

From what I understand, your setup looks like this. You have a function $f : \mathbb{R}^a \times \mathbb{R}^b \to \mathbb{R}^c$ with inputs $x \in \mathbb{R}^a$, $\theta \in \mathbb{R}^b$. You want to optimize (with respect to $x$) another function that is defined in terms of the Jacobian of $f(x, \theta)$ with respect to $\theta$. For example, your function might be $h(x) = \lVert J_\theta f(x, \theta) \rVert$, and you're trying to compute $\nabla h(x)$.
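
As an illustration, the simplest instance of this setup can be worked by hand: a single-output linear model ($c = 1$), chosen here only because its Jacobian has a closed form. With weights $w \in \mathbb{R}^a$ and bias $w_0 \in \mathbb{R}$ (so $b = a + 1$):

$$
f(x, \theta) = w^\top x + w_0, \qquad \theta = (w, w_0) \in \mathbb{R}^{a+1}
$$

$$
J_\theta f(x, \theta) = \begin{bmatrix} x^\top & 1 \end{bmatrix}, \qquad
h(x) = \lVert J_\theta f(x, \theta) \rVert = \sqrt{\lVert x \rVert^2 + 1}, \qquad
\nabla h(x) = \frac{x}{\sqrt{\lVert x \rVert^2 + 1}}
$$

The Jacobian depends on $x$, so $h$ has a nonzero gradient with respect to $x$ even though $\theta$ is held fixed.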

A few questions:

Ludvins commented 1 year ago

Hi @fKunstner

The framework is similar to what you pointed out in your first question: $x$ is a batch of $n$ points of dimension $a$, and I am trying to get the $n \times b$ Jacobian, with one entry per element of $x$ and parameter. These Jacobians are then used to create the NTK matrix $J_\theta f(x, \theta) \, J_\theta f(x, \theta)^\top$, which is an $n \times n$ matrix. The function I minimize then uses this matrix.

I will take a look at JVPs, thank you so much.
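
For cross-checking, the $n \times n$ NTK matrix described above can also be built without BackPACK using `torch.func` (PyTorch 2.0 or later assumed). A sketch under assumptions: a single-output model, a hypothetical small MLP, and a loss `K.trace()` chosen only for illustration. The parameters are detached because only `x` is optimized:

```python
import torch
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
n, a = 4, 3  # hypothetical batch size and input dimension
model = torch.nn.Sequential(torch.nn.Linear(a, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
# Parameters are fixed: detach them so gradients only flow to x.
params = {k: v.detach() for k, v in model.named_parameters()}
x = torch.randn(n, a, requires_grad=True)

def f_single(p, xi):
    # Scalar output f(x_i; theta) for a single input (c = 1).
    return functional_call(model, p, (xi.unsqueeze(0),)).squeeze()

def jacobian_matrix(x):
    # Per-sample gradients wrt every parameter, flattened into an n x b matrix.
    jac = vmap(jacrev(f_single), in_dims=(None, 0))(params, x)
    return torch.cat([jac[k].reshape(x.shape[0], -1) for k in params], dim=1)

J = jacobian_matrix(x)  # (n, b)
K = J @ J.T             # empirical NTK, (n, n)

# Any differentiable function of K can be backpropagated to x.
loss = K.trace()
(dloss_dx,) = torch.autograd.grad(loss, x)
print(J.shape, K.shape, dloss_dx.shape)
```

Regular autograd composes with `torch.func` transforms here, so no `.grad` fields are touched and no `create_graph` warning arises.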

fKunstner commented 1 year ago

Got it, assuming the other difference is that c = 1.

If you only want to use this matrix through matrix-vector products, i.e. computing `v.T @ J.T @ J @ v` for some `v`, Jacobian-vector products might be faster. If you need it many times, computing the matrix once might be faster.
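
A minimal sketch of the matrix-free route for the NTK matrix $J J^\top$, assuming PyTorch 2.0 or later: one vector-Jacobian product gives $J^\top u$, and one Jacobian-vector product then gives $J (J^\top u)$, so $J$ is never formed. The model and sizes are hypothetical:

```python
import torch
from torch.func import functional_call, jvp, vjp

torch.manual_seed(0)
n, a = 4, 3  # hypothetical batch size and input dimension
model = torch.nn.Sequential(torch.nn.Linear(a, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
params = {k: v.detach() for k, v in model.named_parameters()}
x = torch.randn(n, a)

def outputs(p):
    # F(theta) = [f(x_1; theta), ..., f(x_n; theta)], shape (n,)
    return functional_call(model, p, (x,)).squeeze(-1)

def ntk_matvec(u):
    # J^T u via a vector-Jacobian product (one reverse pass) ...
    _, vjp_fn = vjp(outputs, params)
    (jt_u,) = vjp_fn(u)
    # ... then J (J^T u) via a Jacobian-vector product (one forward pass).
    _, k_u = jvp(outputs, (params,), (jt_u,))
    return k_u

u = torch.randn(n)
Ku = ntk_matvec(u)
quad = u @ Ku  # u^T J J^T u = ||J^T u||^2, computed without forming J
```

Each product costs roughly two passes through the network, which is why repeated use can make computing the full matrix once the cheaper option.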

Re: how to do it, the sample code at the end of this post might help.

Caution: I haven't checked that the return values are correct. Check that the values match some other implementation; see for example the checks in the "Gradient of backpropagated quantities" example.
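
One BackPACK-independent reference for the sample code at the end of this post is `torch.func` (PyTorch 2.0 or later assumed). For the same `nn.Linear` setup, both can also be compared with the closed form $[x_i^\top, 1]$. A minimal sketch; the BackPACK result `J` could then be compared with `J_ref` via `torch.allclose`:

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
batch_size, dim = 2, 3
x = torch.randn(batch_size, dim)
f = torch.nn.Linear(dim, 1)
params = {k: v.detach() for k, v in f.named_parameters()}

def f_single(p, xi):
    # Scalar output for one sample, so grad() applies.
    return functional_call(f, p, (xi.unsqueeze(0),)).squeeze()

# Reference per-sample gradients wrt the parameters, computed with torch.func
per_sample = vmap(grad(f_single), in_dims=(None, 0))(params, x)
J_ref = torch.cat(
    [per_sample["weight"].reshape(batch_size, -1),
     per_sample["bias"].reshape(batch_size, -1)],
    dim=1,
)

# For a linear layer the per-sample Jacobian is known in closed form: [x_i, 1]
J_closed_form = torch.cat([x, torch.ones(batch_size, 1)], dim=1)
print(torch.allclose(J_ref, J_closed_form))  # True
```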


> The locations retain a `.grad` value when using the backward call to compute the Jacobians w.r.t. the parameters using BackPACK. The backward pass creates a grad on `x` that is not desired.

Interesting, I've noticed this. `x` needs `requires_grad` to be able to compute derivatives w.r.t. `x` later on, but we can't skip computing the gradient of `x` and still have BackPACK work. The workaround in the code below is to save `x.grad` before the computation, clear it, and restore it afterwards. It's not elegant and it might break higher-order derivatives w.r.t. `x`, but for the gradient it should work.
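
The save/clear/restore pattern can be packaged as a context manager. A sketch in plain PyTorch; `preserved_grad` is a hypothetical helper name, not part of BackPACK or PyTorch:

```python
import contextlib

import torch

@contextlib.contextmanager
def preserved_grad(t: torch.Tensor):
    """Hypothetical helper: restore `t.grad` on exit, whatever the block did to it."""
    saved = None if t.grad is None else t.grad.detach().clone()
    t.grad = None
    try:
        yield
    finally:
        t.grad = saved  # may be None, which also clears any newly accumulated gradient

x = torch.randn(2, 3, requires_grad=True)
f = torch.nn.Linear(3, 1)

with preserved_grad(x):
    f(x).sum().backward()  # this backward pass writes into x.grad ...
print(x.grad)  # None: ... but the original (absent) gradient is restored
```

Restoring unconditionally, including the `None` case, is what prevents the unwanted gradient from leaking out of the block.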

For the technical info @f-dangel: when I try to use `f(x).sum().backward(inputs=list(f.parameters()))`, which would be the way to avoid computing the gradients w.r.t. `x`, the backward hook does not get called. I'm wondering if this is a PyTorch issue, as the docs say that the hook gets called if we need the gradient w.r.t. the inputs (here the inputs have `requires_grad`, but it is locally disabled because they are not in the `inputs` list). We're in this weird state where we don't need the gradient w.r.t. the inputs, but we still need the hook to be called. Presumably this is not an issue in normal operation because `x` then does not have `requires_grad` (or is a leaf), so everything is fine? This is weird.
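
In plain PyTorch, without BackPACK, the `inputs` argument of `backward` restricts gradient accumulation to the listed tensors. A minimal sketch of that behavior (whether BackPACK's hooks fire in this mode is the open question above, and this snippet does not test it):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
f = torch.nn.Linear(3, 1)

# Only the tensors listed in `inputs` receive a .grad; x is left untouched.
f(x).sum().backward(inputs=list(f.parameters()))
print(x.grad)               # None
print(f.weight.grad.shape)  # torch.Size([1, 3])
```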

> The backward call of my loss function does not alter the `.grad` value of the inputs used to compute the Jacobians.

You should be able to get the gradient through `dh_dx = torch.autograd.grad(outputs=h, inputs=x)`, where `h` is the loss computed from the Jacobian obtained with BackPACK. See the code below.

> I have made a workaround using `create_graph=True` in the backward call of the above code. However, this is probably not a good workaround, as I am getting the following warning: `UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak.`

This is the correct workaround AFAIK. It is a warning because you need to be careful when deallocating memory (e.g. setting variables to `None`). If you have to use this function, make sure to reset the `.grad` fields of your parameters to `None` after use to break the cycle and avoid the leak.
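
A minimal plain-PyTorch sketch of that cleanup step (the model is illustrative; `zero_grad(set_to_none=True)` is the standard `nn.Module` call for resetting every parameter's `.grad` to `None`):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
f = torch.nn.Linear(3, 1)

# backward(create_graph=True) stores gradients that are themselves
# differentiable; PyTorch may warn about a reference cycle here.
f(x).sum().backward(create_graph=True)
print(f.weight.grad.requires_grad)  # True

# Once these gradients are no longer needed, drop them to break the cycle.
f.zero_grad(set_to_none=True)
x.grad = None
print(all(p.grad is None for p in f.parameters()))  # True
```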


```python
from typing import Iterable

import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad

def individual_grads_backpropable(x, f: torch.nn.Module):
    """Individual gradients of f(x_i), with respect to the parameters of f,
    with backprop wrt x enabled.

    `x` is assumed to be a `n x d` tensor
    `f` is assumed to be a function `f : n x d -> n`
    """
    x.requires_grad = True

    # Save x's existing gradient (if any) so that this call does not change it
    x_grad_copy = None
    if x.grad is not None:
        x_grad_copy = x.grad.clone()
    x.grad = None

    # create_graph=True keeps grad_batch differentiable wrt x
    with backpack(BatchGrad(), retain_graph=False):
        f(x).sum().backward(create_graph=True)

    # Restore the old gradient; this also discards the gradient the backward
    # pass above accumulated into x (x.grad is None again if it was None before)
    x.grad = x_grad_copy

    return [w.grad_batch for w in f.parameters()]

def batched_parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Convert parameters to one matrix. First dimension of each tensor is batch.
    See https://pytorch.org/docs/stable/_modules/torch/nn/utils/convert_parameters.html#parameters_to_vector
    """
    return torch.cat([p.reshape(p.shape[0], -1) for p in parameters], dim=1)

##
# Create the model and the data

torch.manual_seed(0)
batch_size, dim = 2, 3
x = torch.randn((batch_size, dim))
x.requires_grad = True
f = torch.nn.Linear(dim, 1)

##
# Run BackPACK to compute the individual gradients, reshape into an `n x b` matrix
# (b = number of parameters)
extend(f)
Js = individual_grads_backpropable(x, f)
J = batched_parameters_to_vector(Js)

##
# Compute loss using the Jacobians and get the gradient wrt x
v = torch.randn(J.shape[1], 1)  # one entry per parameter: dim weights + 1 bias
h = v.T @ J.T @ J @ v

dh_dx = torch.autograd.grad(outputs=h, inputs=x)
print(dh_dx)

##
# Cleanup to avoid memory leaks
for w in f.parameters():
    w.grad = None
    w.grad_batch = None
```
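
As a BackPACK-free cross-check of the example above, the same loss can be built from per-sample gradients computed with a plain Python loop over `torch.autograd.grad(..., create_graph=True)` (slow, for illustration only). For this linear model $J = [x, 1]$, so $h = \lVert J v \rVert^2$ and the gradient has the closed form $\partial h / \partial x_i = 2 (J v)_i \, v_w^\top$, where $v_w$ denotes the first `dim` entries of `v`:

```python
import torch

torch.manual_seed(0)
batch_size, dim = 2, 3
x = torch.randn(batch_size, dim, requires_grad=True)
f = torch.nn.Linear(dim, 1)
params = list(f.parameters())

# Per-sample Jacobian rows, one autograd.grad call per sample, kept in the graph
rows = []
for i in range(batch_size):
    grads = torch.autograd.grad(f(x[i : i + 1]).sum(), params, create_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
J = torch.stack(rows)  # (batch_size, dim + 1)

v = torch.randn(J.shape[1], 1)
h = (v.T @ J.T @ J @ v).squeeze()
(dh_dx,) = torch.autograd.grad(outputs=h, inputs=x)

# Closed form for a linear layer: J = [x, 1], so dh/dx = 2 (J v) v_w^T
expected = 2 * (J.detach() @ v) @ v[:dim].T
print(torch.allclose(dh_dx, expected, atol=1e-5))  # True
```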