aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
450 stars, 70 forks

Regression covariance is only diagonal, with the same value across it #100

Closed: ArturPrzybysz closed this issue 4 months ago.

ArturPrzybysz commented 2 years ago

I have created a last-layer BNN with your package, using the "kron" and "diag" Hessian structures in a regression task. However, as stated in the title, every entry on the diagonal of the covariance matrix has the same value.

Is this expected behavior?

I can provide a minimal code example if this is not expected and you suspect an error in the implementation.

aleximmer commented 2 years ago

Hi Artur, thanks for pointing out this issue. It would be great if you could provide a minimal example to reproduce it so we can look into it. Mathematically this should not happen, or at least only in very special cases.

ArturPrzybysz commented 2 years ago

@AlexImmer Thank you for the response! The example is not minimal, but it illustrates the situation well enough. Training the MAP model from scratch is optional; I also provide a pretrained model state_dict to fit the Laplace approximation to.

The code is here: https://github.com/ArturPrzybysz/LaplaceConvDemo

The entry point is src/main.py, where the predictions are visualized; the diagonal of the covariance matrix is also printed and asserted to consist of a single repeated value.

wiseodd commented 2 years ago

Hi Artur, it seems that there is indeed an issue with the glm predictive. When I change this line in your main.py:

pred = la_model.la(X)

into

pred = la_model.la(X, pred_type='nn', n_samples=10)

the variances no longer all have the same value. We'll investigate this further, but for now you can use the nn predictive, which should work fine (make sure to increase n_samples, though).
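
To illustrate why n_samples matters, here is a NumPy sketch (not the library's implementation) of a Monte Carlo ("nn"-style) predictive for a linear last layer $f(x) = W \phi(x)$. It assumes a hypothetical posterior in which the rows of $W$ are independent Gaussians sharing one covariance $S$; the sizes, data, `exact_var`, and `mc_output_variances` are all made up for illustration. Because the layer is linear in $W$, every output has the same exact variance $\phi(x_*)^\top S \phi(x_*)$ under this assumed posterior, so with few samples the per-output estimates differ only through sampling noise.

```python
import numpy as np

rng = np.random.default_rng(0)
c, d = 5, 10  # hypothetical sizes: outputs, last-layer features

# Hypothetical last-layer posterior: rows of W are independent N(W_map[k], S).
Phi = rng.standard_normal((16, d))                # training features phi(x_i)
S = np.linalg.inv(Phi.T @ Phi + 1.0 * np.eye(d))  # shared d x d row covariance
W_map = rng.standard_normal((c, d))
phi_star = rng.standard_normal(d)                 # features of one test input
L = np.linalg.cholesky(S)

# Closed-form ("glm"-style) predictive variance, identical for every output.
exact_var = phi_star @ S @ phi_star


def mc_output_variances(n_samples):
    """Monte Carlo ("nn"-style) estimate of the c output variances."""
    eps = rng.standard_normal((n_samples, c, d))
    W = W_map + eps @ L.T        # each sampled row has covariance L @ L.T = S
    f = W @ phi_star             # (n_samples, c) sampled outputs
    return f.var(axis=0, ddof=1)


print(exact_var)
print(mc_output_variances(10))       # noisy estimates, entries differ
print(mc_output_variances(100_000))  # entries close to exact_var
```

The Monte Carlo error of each variance estimate shrinks roughly like $1/\sqrt{n}$, which is why a small n_samples gives visibly different per-output values.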

ArturPrzybysz commented 2 years ago

Sweet! Thank you for the help.

runame commented 2 years ago

@ArturPrzybysz One other thing I noticed while briefly looking at your code: the tuning of the prior precision is probably not working as you intended. You pass the validation loader but keep the default method='marglik' argument, so the validation loader is not used at all. If you want to use it, you have to set method='CV' and also change the loss argument, which defaults to the cross-entropy loss, whereas for regression you probably want something like the MSE loss. We will add additional checks to avoid this unintended behavior. Moreover, you should also specify the pred_type and link_approx arguments so that they are consistent with the settings you use for prediction later on.

ArturPrzybysz commented 2 years ago

@runame Wow, thank you!

wiseodd commented 4 months ago

Revisiting this issue using the attached quick script.

This problem occurs in last-layer Laplace (all-layer Laplace is fine), for every Hessian structure, with the following backends:

from laplace import Laplace, ParametricLaplace
from laplace.curvature import (
    CurvlinopsGGN,
    CurvlinopsEF,
    CurvlinopsHessian,
    AsdlGGN,
    AsdlEF,
    AsdlHessian,
    BackPackGGN,
    BackPackEF,
)
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

model = nn.Sequential(nn.Linear(3, 10), nn.ReLU(), nn.Linear(10, 5))
trainloader = DataLoader(
    TensorDataset(torch.randn(16, 3), torch.randn((16, 5))), batch_size=3
)
testloader = DataLoader(
    TensorDataset(torch.randn(7, 3), torch.randn((7, 5))), batch_size=3
)
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

for _ in range(100):
    for x, y in trainloader:
        opt.zero_grad()
        out = model(x)
        loss = F.mse_loss(out, y)
        loss.backward()
        opt.step()

la = Laplace(
    model,
    likelihood="regression",
    subset_of_weights="last_layer",
    hessian_structure="kron",
    backend=CurvlinopsGGN,
)
la.fit(trainloader)
# la.optimize_prior_precision()

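# Print the 5 predictive variances (diagonal of each 5 x 5 covariance)
# for every test input.
for x, _ in testloader:
    pred_mean, pred_cov = la(x)  # GLM predictive: pred_cov has shape (batch, 5, 5)
    pred_var = torch.diagonal(pred_cov, dim1=-2, dim2=-1)
    print(pred_var)

wiseodd commented 4 months ago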

This is actually a limitation of Bayesian linear regression with an isotropic Gaussian prior in general.

Let $f(x) = W \phi(x)$ be the model, where $f(x) \in \mathbb{R}^c$, $W \in \mathbb{R}^{c \times d}$, and $\phi(x) \in \mathbb{R}^d$. Let $\mathrm{vec}(W) \sim \mathcal{N}(0, \sigma_0^2 I_{cd \times cd})$.

Given a dataset $\mathcal{D} = \{ (x_i, y_i) \}_{i=1}^n$, the (exact) Hessian (equivalently, the GGN and the Fisher) is:

$$ H = \sum_{i=1}^n (\phi(x_i) \otimes I_{c \times c}) \, I_{c \times c} \, (\phi(x_i) \otimes I_{c \times c})^\top = \Big( \sum_{i=1}^n \phi(x_i) \phi(x_i)^\top \Big) \otimes I_{c \times c}, $$

where the middle matrix $I_{c \times c}$ is the Hessian of the MSE loss with respect to $f(x_i)$ (Gaussian likelihood with unit noise variance).
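
As a quick numerical check of this formula, here is a NumPy sketch on random data. It assumes the column-major convention for $\mathrm{vec}(W)$ (stacking the columns of $W$), under which the Jacobian of $f(x)$ with respect to $\mathrm{vec}(W)$ is $(\phi(x) \otimes I_{c \times c})^\top$; the sizes and data are hypothetical, and this is not how the library computes the GGN.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, c = 16, 4, 3  # hypothetical sizes: data points, features, outputs
I_c = np.eye(c)

# Column-major vec: f(x) = W phi(x) = (phi(x)^T kron I_c) vec(W).
W = rng.standard_normal((c, d))
phi = rng.standard_normal(d)
J = np.kron(phi[None, :], I_c)  # (c, c*d) Jacobian of f wrt vec(W)
print(np.allclose(J @ W.flatten(order="F"), W @ phi))

# H = sum_i (phi_i kron I_c) I_c (phi_i kron I_c)^T, with I_c the MSE Hessian wrt f.
Phi = rng.standard_normal((n, d))  # rows are phi(x_i)
H = sum(np.kron(p[:, None], I_c) @ I_c @ np.kron(p[:, None], I_c).T for p in Phi)

A = Phi.T @ Phi  # sum_i phi(x_i) phi(x_i)^T
print(np.allclose(H, np.kron(A, I_c)))  # H = A kron I_c
```

Both checks follow from the mixed-product rule $(P \otimes Q)(R \otimes T) = PR \otimes QT$.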

Notice the Kronecker structure: $H = A \otimes I_{c \times c}$ with $A = \sum_i \phi(x_i) \phi(x_i)^\top \in \mathbb{R}^{d \times d}$. Up to a reordering of the parameters (grouping the weights by output unit), $H$ is block diagonal with $c$ identical blocks, each equal to $A$. Hence the diagonal of $H$ also has a repeating structure. The same holds for the posterior covariance: with prior precision $\sigma_0^{-2}$, $\Sigma = (H + \sigma_0^{-2} I_{cd \times cd})^{-1} = (A + \sigma_0^{-2} I_{d \times d})^{-1} \otimes I_{c \times c}$.
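
A NumPy sketch can check these structural claims on random data (hypothetical sizes and prior precision, column-major $\mathrm{vec}(W)$ assumed): the posterior covariance factorizes over the Kronecker product, reordering the parameters by output unit makes $H$ block diagonal with $c$ copies of $A$, and the diagonal of $\Sigma$ repeats each value $c$ times.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, c = 16, 4, 3  # hypothetical sizes
prior_prec = 0.25   # hypothetical sigma_0^{-2}

Phi = rng.standard_normal((n, d))
A = Phi.T @ Phi
H = np.kron(A, np.eye(c))  # last-layer GGN under column-major vec(W)

# Posterior covariance factorizes as (A + prior_prec I_d)^{-1} kron I_c.
Sigma = np.linalg.inv(H + prior_prec * np.eye(c * d))
S = np.linalg.inv(A + prior_prec * np.eye(d))
print(np.allclose(Sigma, np.kron(S, np.eye(c))))

# Reorder parameters so each output unit's weights are contiguous (row-major
# vec(W)); H then becomes block diagonal with c copies of A.
perm = np.arange(c * d).reshape(d, c).T.ravel()
print(np.allclose(H[np.ix_(perm, perm)], np.kron(np.eye(c), A)))

# Every value on the diagonal of Sigma is repeated c times (once per output).
diag = np.diag(Sigma).reshape(d, c)
print(np.allclose(diag, diag[:, :1]))
```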

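Now, when making a prediction at a test input $x_*$, we compute the $c \times c$ predictive covariance $(\phi(x_*) \otimes I_{c \times c})^\top \Sigma \, (\phi(x_*) \otimes I_{c \times c})$. Because every matrix in this product shares the same Kronecker structure, it simplifies to $\phi(x_*)^\top \big( \sum_{i=1}^n \phi(x_i) \phi(x_i)^\top + \sigma_0^{-2} I_{d \times d} \big)^{-1} \phi(x_*) \, I_{c \times c}$, a scalar multiple of the identity: diagonal, with the same value at every output coordinate.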
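
In the same style of NumPy sketch (random data, hypothetical sizes and prior precision, column-major $\mathrm{vec}(W)$ assumed), the predictive covariance at a random test feature vector comes out as exactly $\phi(x_*)^\top S \, \phi(x_*)$ times $I_{c \times c}$, where $S = (\sum_i \phi(x_i)\phi(x_i)^\top + \sigma_0^{-2} I_{d \times d})^{-1}$.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d, c = 16, 4, 3  # hypothetical sizes
prior_prec = 0.25   # hypothetical sigma_0^{-2}

Phi = rng.standard_normal((n, d))
S = np.linalg.inv(Phi.T @ Phi + prior_prec * np.eye(d))
Sigma = np.kron(S, np.eye(c))  # posterior covariance of vec(W)

phi_star = rng.standard_normal(d)                # features of a test input x_*
J_star = np.kron(phi_star[None, :], np.eye(c))   # (phi_* kron I_c)^T, shape (c, c*d)
pred_cov = J_star @ Sigma @ J_star.T             # c x c predictive covariance

scale = phi_star @ S @ phi_star
print(np.allclose(pred_cov, scale * np.eye(c)))  # scalar multiple of I_c
```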

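Note that this happens universally for the GGN and the Hessian, regardless of the factorization (diag, kron, full). It is also clear why the EF does not have this issue: its middle matrix is the outer product of the loss gradients, $g_i g_i^\top$ with $g_i = f(x_i) - y_i$, instead of $I_{c \times c}$, which breaks the identical-block structure (albeit at the cost of using the "wrong" Hessian).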
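
A NumPy sketch contrasting the two curvature choices on random data, assuming unit observation noise and a random stand-in for the MAP weights (so the residuals, and hence the EF, are illustrative only); `predictive_variances` is a hypothetical helper. The GGN uses $I_{c \times c}$ as the middle matrix, so each term is $\phi_i \phi_i^\top \otimes I_{c \times c}$; the EF term is $\phi_i \phi_i^\top \otimes g_i g_i^\top$, and the sum over $i$ no longer factors.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, c = 16, 4, 3  # hypothetical sizes
prior_prec = 1.0    # hypothetical isotropic prior precision

Phi = rng.standard_normal((n, d))
Y = rng.standard_normal((n, c))
W = rng.standard_normal((c, d))  # random stand-in for the MAP weights
G = Phi @ W.T - Y                # residuals f(x_i) - y_i = MSE gradients wrt f


def predictive_variances(H, phi_star):
    """Diagonal of the c x c predictive covariance for curvature H."""
    Sigma = np.linalg.inv(H + prior_prec * np.eye(c * d))
    J = np.kron(phi_star[None, :], np.eye(c))
    return np.diag(J @ Sigma @ J.T)


# GGN: middle matrix I_c.  EF: middle matrix g_i g_i^T.
H_ggn = sum(np.kron(np.outer(p, p), np.eye(c)) for p in Phi)
H_ef = sum(np.kron(np.outer(p, p), np.outer(g, g)) for p, g in zip(Phi, G))

phi_star = rng.standard_normal(d)
print(predictive_variances(H_ggn, phi_star))  # all entries equal
print(predictive_variances(H_ef, phi_star))   # entries generally differ
```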

The fix

This limitation comes from the linear model itself under the isotropic prior, so the fix is simply to move away from that prior. In Laplace, this is easy:

la = Laplace(model, "regression", subset_of_weights="last_layer", ...)
la.fit(trainloader)
# One prior precision per parameter instead of a single isotropic one
la.optimize_prior_precision(prior_structure="diag")
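
A NumPy sketch of why a non-isotropic prior helps, assuming the last-layer GGN structure $H = A \otimes I_{c \times c}$ derived above (column-major $\mathrm{vec}(W)$, random data, hypothetical sizes). The per-parameter precisions here are random stand-ins, not the values `optimize_prior_precision(prior_structure="diag")` would actually learn, and `predictive_variances` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, c = 16, 4, 3  # hypothetical sizes

Phi = rng.standard_normal((n, d))
H = np.kron(Phi.T @ Phi, np.eye(c))  # last-layer GGN, column-major vec(W)
phi_star = rng.standard_normal(d)
J = np.kron(phi_star[None, :], np.eye(c))


def predictive_variances(prior_prec_diag):
    """Diagonal of the predictive covariance under a diagonal prior precision."""
    Sigma = np.linalg.inv(H + np.diag(prior_prec_diag))
    return np.diag(J @ Sigma @ J.T)


isotropic = np.full(c * d, 1.0)
per_param = rng.uniform(0.1, 10.0, size=c * d)  # hypothetical per-parameter precisions

print(predictive_variances(isotropic))  # identical entries
print(predictive_variances(per_param))  # entries generally differ
```

What breaks the symmetry is that the precisions differ between the weights of different output units; a prior that varied only across features (the same for every output) would still give $\Sigma = B \otimes I_{c \times c}$ and hence equal predictive variances.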

wiseodd commented 4 months ago

Sorry @ArturPrzybysz for taking 2 years to answer this issue!

ArturPrzybysz commented 4 months ago

@wiseodd no problem, this package helped me a ton with my MSc thesis anyway!