aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

3D input tensors and feature reduction #252

Open wmloh opened 3 days ago

wmloh commented 3 days ago

@wiseodd TL;DR: issue with tensors of size (B, L, D) passing through a Linear last layer.

Here's the minimal reproducible example:

import torch
import torch.nn as nn
from laplace import Laplace
from torch.utils.data import DataLoader
from tensordict import TensorDict

BATCH_SIZE = 4  # B
SEQ_LENGTH = 6  # L
EMBED_DIM = 8  # D
INPUT_KEY = "input"
OUTPUT_KEY = "output"

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1, batch_first=True)  # inputs are (B, L, D)
        self.final_layer = nn.Linear(EMBED_DIM, 1)

    def forward(self, x):
        x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D) 
        out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
        return self.final_layer(out).squeeze(dim=-1)  # (B, L)

ds = TensorDict({INPUT_KEY: torch.randn((100, SEQ_LENGTH * EMBED_DIM)),
                 OUTPUT_KEY: torch.randn((100, SEQ_LENGTH * 1))},
                batch_size=[100])  # simulates a dataset
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: x)

model = Model()
la = Laplace(model, "regression", dict_key_x=INPUT_KEY, dict_key_y=OUTPUT_KEY,
             last_layer_name="final_layer", feature_reduction="average")
la.fit(dl)

data = next(iter(dl))  # data[INPUT_KEY].shape = (B, L * D)
pred_map = model(data)  # (B, L)
pred_lap = la(data)  # error: shape '[4, 6, 54]' is invalid for input of size 216

My goal is to obtain some measure of epistemic uncertainty for each "token" of the output. What differs from the tutorials I reviewed is the 3D input tensor, which I need for attention. The feature_reduction parameter seemed to help a little while I was debugging, but I'm not very familiar with this functionality.

I find it surprising that la.fit(dl) works but the forward call la(data) doesn't work. How do you recommend I use this library properly for this use-case? Thanks in advance.
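
For context on feature_reduction: assuming it collapses the sequence axis of the features that feed final_layer (with "average" taking the mean over tokens), its effect on shapes looks roughly like the plain-PyTorch sketch below. This is an illustration, not laplace-torch's implementation, and the "last token" variant is an assumed alternative shown only for comparison.

```python
import torch

B, L, D = 4, 6, 8  # batch size, sequence length, embedding dim, as in the example

# Stand-in for the inputs to final_layer: one D-dim feature vector per token.
features = torch.randn(B, L, D)

# Illustrative reductions over the sequence axis (not laplace-torch's code).
averaged = features.mean(dim=1)   # "average"-style reduction -> (B, D)
last_token = features[:, -1, :]   # keep only the last token  -> (B, D)

print(features.shape, averaged.shape, last_token.shape)
# torch.Size([4, 6, 8]) torch.Size([4, 8]) torch.Size([4, 8])
```

If that assumption holds, the reduced features carry one vector per sequence, so on their own they cannot give per-token uncertainty.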

wiseodd commented 1 day ago
  1. For better flexibility, use Laplace(..., subset_of_weights="all", ...) and then switch off the gradients of the parameters you don't need. E.g. if you want last-layer Laplace, keep subset_of_weights="all" but set requires_grad = False on all parameters except those of the last layer.
  2. Currently, the GLM predictive doesn't work well with multi-dimensional batch outputs, but the MC approximation does: set pred_type="nn", link_approx="mc" when making predictions.
  3. Don't squeeze the output dimensions, esp. the class dimension, in your model. Otherwise laplace-torch won't know how many outputs/classes you have.
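
Point 3 can be checked with shapes alone. In this sketch (plain PyTorch, with a random tensor standing in for the attention output), squeezing leaves the sequence length L = 6 as the trailing axis, where it could presumably be mistaken for 6 outputs, while the unsqueezed output keeps a trailing axis of size 1 for the single regression target.

```python
import torch
import torch.nn as nn

B, L, D = 4, 6, 8
final_layer = nn.Linear(D, 1)
out = torch.randn(B, L, D)  # stand-in for the attention output

squeezed = final_layer(out).squeeze(dim=-1)  # (B, L): the output axis is gone
unsqueezed = final_layer(out)                # (B, L, 1): trailing axis = 1 output

print(squeezed.shape, unsqueezed.shape)
# torch.Size([4, 6]) torch.Size([4, 6, 1])
```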

I will try to make (2) above work for GLM.

For now, this script works:

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.utils.data import DataLoader

from laplace import Laplace
from laplace.curvature.asdl import AsdlGGN
from laplace.utils.enums import LinkApprox, PredType

BATCH_SIZE = 4  # B
SEQ_LENGTH = 6  # L
EMBED_DIM = 8  # D
INPUT_KEY = "input"
OUTPUT_KEY = "output"

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1, batch_first=True)  # inputs are (B, L, D)
        self.final_layer = nn.Linear(EMBED_DIM, 1)

    def forward(self, x):
        x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D)
        out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
        return self.final_layer(out)  # (B, L, 1)

ds = TensorDict(
    {
        INPUT_KEY: torch.randn((100, SEQ_LENGTH, EMBED_DIM)),
        OUTPUT_KEY: torch.randn((100, SEQ_LENGTH, 1)),
    },
    batch_size=[100],
)  # simulates a dataset
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: x)

model = Model()

for mod_name, mod in model.named_modules():
    if mod_name == "final_layer":
        for p in mod.parameters():
            p.requires_grad = True
    else:
        for p in mod.parameters():
            p.requires_grad = False

la = Laplace(
    model,
    "regression",
    hessian_structure="diag",
    subset_of_weights="all",
    backend=AsdlGGN,
    dict_key_x=INPUT_KEY,
    dict_key_y=OUTPUT_KEY,
)
la.fit(dl)

data = next(iter(dl))  # data[INPUT_KEY].shape = (B, L, D)
pred_map = model(data)  # (B, L, 1)
pred_la_mean, pred_la_var = la(
    data, pred_type=PredType.NN, link_approx=LinkApprox.MC, n_samples=10
)

# torch.Size([4, 6, 1]) torch.Size([4, 6, 1])
print(pred_la_mean.shape, pred_la_var.shape)
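
Since the goal is per-token epistemic uncertainty, the returned variance can be reduced to one standard deviation per token. The sketch below uses random stand-ins with the printed shapes instead of real la(...) output. The commented pred_std line is modeled on the regression example in the laplace-torch README, where observation noise la.sigma_noise is added for the full predictive spread; it is not needed for the epistemic part alone.

```python
import torch

torch.manual_seed(0)
# Stand-ins with the shapes printed above: (B, L, 1) = (4, 6, 1).
pred_la_mean = torch.randn(4, 6, 1)
pred_la_var = torch.rand(4, 6, 1)

# Per-token epistemic std, dropping the trailing output axis -> (B, L).
token_std = pred_la_var.squeeze(-1).sqrt()

# Index of the most uncertain token in each sequence -> (B,).
most_uncertain = token_std.argmax(dim=1)

# Full predictive std (epistemic + observation noise), as in the README example:
# pred_std = torch.sqrt(pred_la_var + la.sigma_noise.item() ** 2)
print(token_std.shape, most_uncertain.shape)
# torch.Size([4, 6]) torch.Size([4])
```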

wiseodd commented 1 day ago

Addendum: If you check out the glm-multidim branch, you can use the GLM predictive (better than MC), with some caveats.

See the example: https://github.com/aleximmer/Laplace/blob/glm-multidim/examples/lm_example.py
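
One way to see why the GLM predictive can beat MC here: when the posterior covers only a linear last layer f = w^T phi + b, linearizing in the parameters is exact, and with a diagonal posterior the per-token variance has the closed form sum_i phi_i^2 * var(w_i) + var(b), with no sampling noise. MC only approaches that value as the number of samples grows (the script above uses n_samples=10). The sketch is plain PyTorch with placeholder posterior values; it is not the glm-multidim branch's code.

```python
import torch

torch.manual_seed(0)
B, L, D = 4, 6, 8
phi = torch.randn(B, L, D)  # last-layer inputs, one per token

# Placeholder diagonal Gaussian posterior over final_layer (8 weights + 1 bias).
w_map, b_map = torch.randn(D), torch.randn(1)
w_var, b_var = torch.rand(D) * 0.1, torch.tensor([0.05])

# GLM predictive variance J diag(var) J^T, where J = [phi, 1] is the
# Jacobian of f = w^T phi + b with respect to (w, b).
glm_var = (phi ** 2 * w_var).sum(dim=-1) + b_var  # (B, L)

# Monte Carlo estimate: sample parameters, push each sample through the layer.
n_samples = 100_000
w_s = w_map + w_var.sqrt() * torch.randn(n_samples, D)               # (S, D)
b_s = b_map + b_var.sqrt() * torch.randn(n_samples, 1)               # (S, 1)
f_s = torch.einsum("bld,sd->sbl", phi, w_s) + b_s.view(-1, 1, 1)     # (S, B, L)
mc_var = f_s.var(dim=0)                                              # (B, L)

print((mc_var - glm_var).abs().max())  # shrinks as n_samples grows
```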

wmloh commented 1 day ago

@wiseodd

Thanks for correcting the code. It works on my end, and I've successfully transferred the correction to my complete use-case (at least with respect to the mentioned issue).

Regarding point (1), oddly enough, if I stubbornly insist on setting subset_of_weights="last_layer", it will not work. It needs to be "all". Regardless, it's nothing major.

wiseodd commented 1 day ago

As I said before, subset_of_weights="last_layer" is much less flexible. Just set it to "all" and switch off the gradients of the parameters you don't need.
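
A minimal sketch of switching off gradients by parameter name, with a check that only the last layer (8 weights + 1 bias) stays trainable. The stand-in model reuses the module names from the scripts above and passes batch_first=True (an assumption, so that nn.MultiheadAttention treats dim 0 as the batch); it omits forward because only the parameters matter here.

```python
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, embed_dim: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)
        self.final_layer = nn.Linear(embed_dim, 1)


model = Model()

# Keep only final_layer trainable; everything else is frozen.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("final_layer.")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable, n_trainable)
# ['final_layer.weight', 'final_layer.bias'] 9
```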

wiseodd commented 1 day ago

I'll keep this open until the aforementioned branch is merged. Thanks for opening the issue!