aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
436 stars, 63 forks

Error in Multitask regression case #126

Closed: UtkarshMitta closed this issue 1 year ago

UtkarshMitta commented 1 year ago

I tried to use this library for multitask regression (a network with multiple output units in its last layer), but I get an error when fitting it on my train loader.
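
The thread does not show the actual error. As a hedged starting point for narrowing it down, one thing worth checking before calling `la.fit` is that the loader's targets line up with the model's outputs. This sketch assumes that, for `likelihood='regression'` with several tasks, the targets should be floating-point tensors of shape `(batch_size, n_tasks)`, the same shape as the model output. `check_regression_batch` is a hypothetical helper written for illustration; it is not part of the laplace API.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def check_regression_batch(model: nn.Module, loader: DataLoader) -> None:
    """Hypothetical helper: compare one batch's targets with the model's outputs.

    Assumption: for multi-output regression the targets are float tensors with
    the same shape as the model output, e.g. (batch_size, n_tasks).
    """
    X, y = next(iter(loader))  # take the first batch only
    with torch.no_grad():
        out = model(X)
    if y.shape != out.shape:
        raise ValueError(f"target shape {tuple(y.shape)} != output shape {tuple(out.shape)}")
    if not torch.is_floating_point(y):
        raise TypeError(f"regression targets should be floating point, got {y.dtype}")


# 3 tasks: the model has 3 output units and each target row has 3 columns
model = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 3))
good = TensorDataset(torch.rand(256, 2), torch.rand(256, 3))
check_regression_batch(model, DataLoader(good, batch_size=128))  # passes silently

# Targets with a single column-less dimension do not match a 3-output model
bad = TensorDataset(torch.rand(256, 2), torch.rand(256))
try:
    check_regression_batch(model, DataLoader(bad, batch_size=128))
except ValueError as err:
    print(err)  # target shape (128,) != output shape (128, 3)
```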

wiseodd commented 1 year ago

Hey, can you share a minimal working example that reproduces your error?

There shouldn't be any error. For example, the following works:

import torch
from torch import nn
from laplace import Laplace

# Dummy dataset
dataset = torch.utils.data.TensorDataset(
    torch.rand(256, 2),  # X
    torch.rand(256, 3),  # y
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128)

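# Multi-output regression model: 3 output units, one per task
# (left untrained here for brevity; in practice, fit the Laplace approximation to a trained model)
model = nn.Sequential(
    nn.Linear(2, 10),
    nn.ReLU(),
    nn.Linear(10, 3)
)
# Laplace approximation over all weights with a full Hessian;
# the targets y have shape (N, 3), matching the model output
la = Laplace(model, likelihood='regression', subset_of_weights='all', hessian_structure='full')
la.fit(dataloader)
la.optimize_prior_precision()  # tune the prior precision
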
UtkarshMitta commented 1 year ago

That worked. Thank you so much for helping me out.