aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

[Discussion] Last-layer Laplace for img2img problem #118

Open wiseodd opened 1 year ago

wiseodd commented 1 year ago

@AlexImmer, @runame, @edaxberger: As you know, I'm currently working on last-layer Laplace for img2img tasks, e.g. autoencoders and image segmentation. We can't use the current implementation in this library, mainly because we hard-code the last-layer Jacobian to be the Jacobian of a fully-connected layer; see #111 for an example. Note: GGN computation using BackPACK and ASDL doesn't seem to pose any problem (see #111 for ASDL and below for BackPACK).
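A minimal sketch of why a hard-coded Jacobian only fits fully-connected heads, assuming a bias-free nn.Linear last layer with its weight flattened row-major (this illustrates the idea and is not the library's code). For y = W phi, the Jacobian of y with respect to vec(W) is the Kronecker product kron(I, phi^T), which can be checked against a generic jacrev. It assumes PyTorch >= 2.0, where functorch's transforms are available as torch.func. No such closed form exists for a ConvTranspose2d head, whose output pixels depend on overlapping weight patches.

```python
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)
n_in, n_out = 4, 3
layer = torch.nn.Linear(n_in, n_out, bias=False)
phi = torch.randn(n_in)  # features of a single input

# Closed form for y = W @ phi: dy_k / dW[k', j] = delta(k, k') * phi[j],
# i.e. the (n_out, n_out * n_in) matrix kron(I, phi^T)
J_closed = torch.kron(torch.eye(n_out), phi.unsqueeze(0))

# Generic Jacobian with respect to the flattened weight
def f(w_flat):
    return functional_call(layer, {"weight": w_flat.reshape(n_out, n_in)}, (phi,))

J_generic = jacrev(f)(layer.weight.detach().flatten())

print(J_closed.shape)
print(torch.allclose(J_closed, J_generic))
```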

So my current thinking is to simply generalize last_layer_jacobians in laplace/curvature/curvature.py using functorch; see the predict function below. I also propose to support only diagonal LLLA, since a full last-layer covariance is too costly otherwise.
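A minimal sketch of what a generalized last-layer Jacobian could look like, assuming PyTorch >= 2.0 (torch.func, the successor of functorch). The name generic_last_layer_jacobians is hypothetical and is not the library's API. It differentiates any last-layer module with respect to all of its parameters (flattened and concatenated) and vmaps over the batch, so a ConvTranspose2d head is handled the same way as a Linear one.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap


def generic_last_layer_jacobians(last_layer, phi):
    """Hypothetical helper: per-sample Jacobians of a last layer w.r.t. its params.

    Returns (J, f): J has shape (batch, n_outputs, n_params) with outputs
    flattened per sample, and f holds the outputs, shape (batch, n_outputs).
    """
    params = {k: v.detach() for k, v in last_layer.named_parameters()}
    names = list(params)
    shapes = [params[k].shape for k in names]
    sizes = [params[k].numel() for k in names]
    theta = torch.cat([params[k].flatten() for k in names])

    def f_single(theta_vec, phi_i):
        # Unflatten theta into the module's parameter dict
        chunks = torch.split(theta_vec, sizes)
        p = {k: c.reshape(s) for k, c, s in zip(names, chunks, shapes)}
        # Add a batch dim of 1 for the module call, then flatten the output
        return functional_call(last_layer, p, (phi_i.unsqueeze(0),)).flatten()

    J = vmap(jacrev(f_single), in_dims=(None, 0))(theta, phi)
    f = vmap(f_single, in_dims=(None, 0))(theta, phi)
    return J, f


# Example: a ConvTranspose2d head like the one in the script, on small random features
torch.manual_seed(0)
head = nn.Sequential(nn.ConvTranspose2d(100, 3, kernel_size=5, bias=False), nn.Flatten(1))
phi = torch.randn(2, 100, 4, 4)
J, f = generic_last_layer_jacobians(head, phi)
print(J.shape)  # torch.Size([2, 192, 7500])
```

Because this head is linear in its weights, J @ theta reproduces the head's output, which is a convenient sanity check when swapping in other architectures.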

Let me know your thoughts and if I missed anything. Feel free to try out the self-contained script below.

import torch
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
from torch import nn, optim
from functorch import jacrev
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([transforms.ToTensor()])
train_batch_size = 32
test_batch_size = 10

# TODO Replace root path
trainset = tv.datasets.CIFAR10(
    root='~/Datasets', train=True, transform=transform, download=False
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

# TODO Replace root path
testset = tv.datasets.CIFAR10(
    root='~/Datasets', train=False, transform=transform, download=False
)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False)

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 100, kernel_size=5),
            nn.Sigmoid(),
        )
        self.last_layer = nn.Sequential(
            nn.ConvTranspose2d(100, 3, kernel_size=5, bias=False),
            nn.Flatten(1)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        return self.last_layer(x)

model = Model().to(DEVICE)

for p in model.feature_extractor.parameters():
    p.requires_grad = False

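lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction='sum'))

# Accumulate the diagonal GGN over all minibatches;
# same shape as the ConvTranspose2d weight: (100, 3, 5, 5)
GGN = torch.zeros_like(model.last_layer[0].weight)

for x, _ in trainloader:
    x = x.to(DEVICE)

    # (n_data, n_channel*width*height)
    reconstruction = lastlayer(model.feature_extractor(x))

    loss = lossfunc(reconstruction, x.flatten(1))
    with backpack(DiagGGNExact()):
        loss.backward()

    # diag_ggn_exact is overwritten on every backward pass, so sum it up
    GGN += model.last_layer[0].weight.diag_ggn_exact

# Diagonal covariance: invert (prior precision + GGN) elementwise and
# flatten to (n_ll_params,) to match the Jacobian's last dimension
prec0 = 1
Sigma = 1 / (prec0 + GGN.flatten())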

@torch.no_grad()
def predict(x):
    phi = model.feature_extractor(x)

    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance
    def f(feat, w):
        """ w is vectorized """
        SHAPE = (100, 3, 5, 5)
        return F.conv_transpose2d(feat, w.reshape(SHAPE))

    jac = jacrev(f, argnums=1)

    # (n_data, n_channel, height, width, n_params)
    J_pred = jac(phi, model.last_layer[0].weight.flatten())

    # (n_data, n_channel, height, width); diagonal Sigma: sum_i J_i^2 * Sigma_i
    var_pred = torch.einsum('nabci,i,nabci->nabc', J_pred, Sigma, J_pred)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()

for x, _ in testloader:
    predict(x.to(DEVICE))
    break
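For reference, this is the quantity the script computes, under its own assumptions (linearized last layer, diagonal GGN $G$, isotropic prior precision $\tau$ = prec0). Since both the prior and the GGN approximation are diagonal, the posterior covariance is inverted elementwise rather than with a matrix inverse, and each output's variance only needs the squared Jacobian:

$$
\Sigma = \big(\tau \mathbf{1} + \operatorname{diag}(G)\big)^{-1} \ \text{(elementwise)}, \qquad
\operatorname{diag}(G) = \sum_{n} \operatorname{diag}\big(J_n^\top H_n J_n\big), \qquad
\operatorname{Var}[f_o(x)] = \sum_{i} J_{oi}(x)^2\, \Sigma_{i}
$$

Here $J(x) = \partial f(x)/\partial \theta$ is the Jacobian of the flattened output with respect to the last-layer weights $\theta$, $J_n = J(x_n)$ on training input $x_n$, and $H_n$ is the Hessian of the loss with respect to the output; for MSELoss(reduction='sum') it is $2I$. The predictive mean is the MAP output, as in mean_pred.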
wiseodd commented 1 year ago

Beyond CIFAR images, though, the GGN computation will also become an issue. E.g., for ImageNet the output dimension is 224*224*3 = 150528, much larger than CIFAR's 32*32*3 = 3072. I talked to Felix about this, and one solution is to exploit the per-pixel nature of the loss and compute the minibatch GGN in chunks along the output dimension; see the MNIST example below.

Thoughts?

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction='sum'))

# Accumulated diagonal GGN, same shape as the last-layer weight
chunked_ggn = torch.zeros_like(model.last_layer[0].weight)

for x, _ in trainloader:
    x = x.to(DEVICE)

    # [N, 784]
    reconstruction = lastlayer(model.feature_extractor(x))

    # Process the 784 output pixels in 28 chunks of 28 (one image row each)
    for i in range(28):
        slicing = (slice(None), slice(i * 28, (i + 1) * 28))
        slicing_module = extend(Slicing(slicing))

        sliced_reconstruction = slicing_module(reconstruction)
        sliced_loss = lossfunc(sliced_reconstruction, x.flatten(1)[slicing])

        # retain_graph=True keeps the shared graph alive for the next chunk
        with backpack(DiagGGNExact(), retain_graph=True):
            sliced_loss.backward(retain_graph=True)
            chunked_ggn += model.last_layer[0].weight.diag_ggn_exact
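Why chunking is exact: with a per-pixel loss such as sum-reduced MSE, the loss Hessian with respect to the output is diagonal ($2I$), so the diagonal GGN is a sum over output pixels and any partition of the outputs into chunks adds up to the same result. A small pure-PyTorch check of that identity, without BackPACK; the toy sizes and the jacrev-based reference are assumptions for illustration:

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev

torch.manual_seed(0)
n, c_in, c_out, k, h = 3, 5, 1, 3, 4
phi = torch.randn(n, c_in, h, h)    # frozen features
w = torch.randn(c_in, c_out, k, k)  # ConvTranspose2d weight

def f(w_flat):
    # Flattened reconstruction, shape (n, c_out * (h + k - 1) ** 2)
    return F.conv_transpose2d(phi, w_flat.reshape(w.shape)).flatten(1)

J = jacrev(f)(w.flatten())          # (n, n_out, n_params)
n_out = J.shape[1]

# Diagonal GGN for sum-reduced MSE: Hessian w.r.t. the output is 2 * I
full_diag_ggn = 2 * (J ** 2).sum(dim=(0, 1))

# Same quantity accumulated over chunks of output pixels
chunk = 6
chunked_diag_ggn = torch.zeros_like(full_diag_ggn)
for start in range(0, n_out, chunk):
    J_c = J[:, start:start + chunk]
    chunked_diag_ggn += 2 * (J_c ** 2).sum(dim=(0, 1))

print(torch.allclose(full_diag_ggn, chunked_diag_ggn, rtol=1e-4))
```

The same argument does not hold for losses that couple pixels (e.g. a softmax over a whole image), whose output Hessian is not block-diagonal across chunks.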
wiseodd commented 1 year ago

For predictions/reconstructions, my proposal is to use https://github.com/f-dangel/unfoldNd. With it, conv_transpose2d becomes a plain matrix multiplication between the original weights/filters and the unfolded input, so the linearized predictive $p(f(x))$ is easy to obtain.

import unfoldNd

prec0 = 1

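# Laplace cov (diagonal): diag_GGN is the diagonal GGN of the ConvTranspose2d
# weight, shape (c_in, c_out, k, k); move c_out first and flatten the rest
diag_Sigma = 1/(diag_GGN + prec0)
diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)

# diag_Sigma.shape should be (c_out, c_in*k*k)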
assert len(diag_Sigma.shape) == 2 and diag_Sigma.shape == (1, 100*3*3)

# Following the last layer of the model
unfold_transpose = unfoldNd.UnfoldTransposeNd(
    kernel_size=3, dilation=1, padding=1, stride=3
)

@torch.no_grad()
def reconstruct(x):
    phi = model.feature_extractor(x)

    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance
    J_pred = unfold_transpose(phi)
    var_pred = torch.einsum('bij,ki,bij->bkj', J_pred, diag_Sigma, J_pred).reshape(mean_pred.shape)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()

x_recons = []

for x, _ in testloader:
    x = x.to(DEVICE)
    x_recons.append(reconstruct(x))
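To see why unfolding gives the variance directly, here is a sketch that reproduces the idea with plain torch.nn.functional.unfold instead of unfoldNd, assuming stride 1, no padding and no dilation (unlike the settings used above). A transposed convolution equals a regular convolution of the input zero-padded by k-1 with the spatially flipped, channel-transposed kernel, and that convolution is a matrix product with the unfolded input. The unfolded input is therefore the Jacobian with respect to the (flipped) weight matrix, so the same flip must be applied to the diagonal covariance. The sizes and the random diagonal covariance are made up for the check:

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev

torch.manual_seed(0)
n, c_in, c_out, k, h = 2, 4, 1, 3, 5
phi = torch.randn(n, c_in, h, h)
w = torch.randn(c_in, c_out, k, k)          # ConvTranspose2d layout
diag_sigma = torch.rand(c_in, c_out, k, k)  # made-up diagonal covariance, same layout

# Transposed conv == conv on the (k-1)-padded input with the flipped,
# channel-transposed kernel, i.e. (c_out, c_in*k*k) @ unfolded input
U = F.unfold(F.pad(phi, (k - 1,) * 4), kernel_size=k)  # (n, c_in*k*k, L)
W_mat = w.transpose(0, 1).flip(2, 3).reshape(c_out, -1)
S_mat = diag_sigma.transpose(0, 1).flip(2, 3).reshape(c_out, -1)

h_out = h + k - 1
out_matmul = (W_mat @ U).reshape(n, c_out, h_out, h_out)
out_ref = F.conv_transpose2d(phi, w)

# Per-pixel variance under a diagonal weight covariance (same einsum as reconstruct)
var_unfold = torch.einsum('bij,ki,bij->bkj', U, S_mat, U).reshape(out_ref.shape)

# Reference: explicit Jacobian w.r.t. the original weight
J = jacrev(lambda w_: F.conv_transpose2d(phi, w_.reshape(w.shape)))(w.flatten())
var_ref = torch.einsum('nchwp,p->nchw', J ** 2, diag_sigma.flatten())

print(torch.allclose(out_matmul, out_ref, atol=1e-4))
print(torch.allclose(var_unfold, var_ref, atol=1e-4))
```

The unfolded tensor has size c_in*k*k per output pixel, independent of the number of weights per output channel times the image size, which is what keeps this cheaper than materializing the full Jacobian.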
wiseodd commented 1 year ago

Full, self-contained prototype here: https://gist.github.com/wiseodd/b8d57fa029f876e00b336b7b3b5052bd

JRopes commented 5 months ago

Hello @wiseodd, have there been any updates on this topic over the last year? I am currently working on Laplace approximations for segmentation tasks and would be very interested. Thank you!

wiseodd commented 5 months ago

Unfortunately, there's no update on this. This is partly because the loss function typically used in image problems (BCELoss) is not supported by the Hessian backends, and partly because my research agenda is far from computer vision/graphics.
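As a possible stopgap, a sketch assuming the head outputs logits and the loss is BCE-with-logits (nothing here is claimed about what the backends support): the loss Hessian with respect to the logits is diagonal with entries p(1-p), p = sigmoid(f), so a last-layer diagonal GGN can be assembled by hand from a Jacobian such as the jacrev-based one in predict. The toy linear last layer and random data are hypothetical:

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev

torch.manual_seed(0)
n, d, n_out = 8, 5, 6
phi = torch.randn(n, d)                      # frozen features
W = torch.randn(n_out, d) * 0.3              # toy last-layer weight (hypothetical)
y = torch.randint(0, 2, (n, n_out)).float()  # binary per-pixel targets

def logits(w_flat):
    return phi @ w_flat.reshape(n_out, d).T  # (n, n_out)

w = W.flatten()
J = jacrev(logits)(w)                        # (n, n_out, n_params)
p = torch.sigmoid(logits(w))
hess = p * (1 - p)                           # diagonal BCE Hessian w.r.t. logits

# Diagonal GGN: sum_n sum_o J[n, o, i]^2 * p(1 - p)[n, o]
diag_ggn = torch.einsum('noi,no->i', J ** 2, hess)

# Sanity check: for a linear head the GGN equals the exact Hessian of the loss
loss = lambda w_: F.binary_cross_entropy_with_logits(logits(w_), y, reduction='sum')
exact_hessian = torch.autograd.functional.hessian(loss, w)
print(torch.allclose(diag_ggn, exact_hessian.diagonal(), atol=1e-4))
```

Since this Hessian is per pixel, the chunking argument for MSE carries over unchanged.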

That said, I can point you in a good direction:

In any case, I hope the references above and the snippets in the previous posts are useful for you.

wiseodd commented 3 months ago

This issue should be easier to solve once #145 is merged. I will work on this after the release of milestone 0.2.