aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Confusing behavior for one-dimensional targets #235

Closed: joemrt closed this issue 1 month ago.

joemrt commented 2 months ago

I noticed that for one-dimensional targets the library produces erroneous results if the targets lack a target dimension (i.e., they only have a batch dimension). Here is a quick example with a linear model, for which the Hessian is straightforward to compute:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
import laplace

N = 10     # number of datapoints
in_d = 5   # input dimension

# generate dataset
X = torch.randn((N, in_d))
Y = torch.randn((N,))  # no target dimension
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=5)

# linear net
model = torch.nn.Linear(in_d, 1, bias=False)

# Laplace approximation
la = laplace.FullLaplace(model=model, likelihood='regression')
la.fit(dataloader)

# Compare la.H with the theoretical value of the Hessian, X^T X
H_theory = torch.einsum('ij,ik->jk', X, X)
print(torch.all(torch.isclose(la.H, H_theory)))  # will yield False
```
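Why `H_theory` is the right reference: for a bias-free linear model with a Gaussian likelihood and unit observation noise, the Hessian of the sum-reduced loss 0.5·Σₙ(xₙᵀw − yₙ)² with respect to w is Σₙ xₙxₙᵀ = XᵀX, which is exactly what the `einsum` computes. The sketch below checks this with plain `torch.autograd.functional.hessian`, independently of the Laplace library; the 0.5 scaling, unit noise, float64 dtype and fixed seed are assumptions chosen for a clean comparison.

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)
N, in_d = 10, 5
X = torch.randn(N, in_d, dtype=torch.float64)
Y = torch.randn(N, 1, dtype=torch.float64)  # targets WITH a target dimension

def loss(w):
    # 0.5 * sum of squared residuals; w has shape (in_d,)
    pred = X @ w.unsqueeze(-1)  # shape (N, 1), same shape as Y
    return 0.5 * ((pred - Y) ** 2).sum()

# Hessian of the scalar loss w.r.t. w, shape (in_d, in_d)
H_autograd = hessian(loss, torch.zeros(in_d, dtype=torch.float64))
H_theory = torch.einsum('ij,ik->jk', X, X)  # identical to X.T @ X
print(torch.allclose(H_autograd, H_theory))  # True
```

Because the loss is quadratic in w, the Hessian does not depend on the point at which it is evaluated, so evaluating at zero is enough.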

Once you change `Y=torch.randn((N,))` to `Y=torch.randn((N,1))`, the correct Hessian is returned.
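One way a missing target dimension can silently corrupt a curvature computation is PyTorch broadcasting: subtracting a target of shape (N,) from a model output of shape (N, 1) yields an (N, N) matrix, so every prediction is compared against every target. The sketch below is a plain PyTorch illustration of that effect, not a description of the Laplace library's internals (the thread does not establish which code path is affected). For a single batch of size N, the Hessian of the broadcast loss comes out N times too large; with minibatches, each batch would be scaled by its own batch size instead.

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)
N, in_d = 10, 5
X = torch.randn(N, in_d, dtype=torch.float64)
Y_flat = torch.randn(N, dtype=torch.float64)  # shape (N,): no target dimension
Y_col = Y_flat.unsqueeze(-1)                  # shape (N, 1): the working variant

pred = X @ torch.zeros(in_d, 1, dtype=torch.float64)  # model output, shape (N, 1)
print((pred - Y_flat).shape)  # torch.Size([10, 10]): every prediction vs. every target
print((pred - Y_col).shape)   # torch.Size([10, 1]): elementwise residuals

def make_loss(Y):
    # 0.5 * sum of squared residuals for a bias-free linear model with weights w
    def loss(w):
        return 0.5 * ((X @ w.unsqueeze(-1) - Y) ** 2).sum()
    return loss

w0 = torch.zeros(in_d, dtype=torch.float64)
H_flat = hessian(make_loss(Y_flat), w0)
H_col = hessian(make_loss(Y_col), w0)
print(torch.allclose(H_col, X.T @ X))     # True
print(torch.allclose(H_flat, N * H_col))  # True: N times too large
```

Reshaping the targets with `Y.unsqueeze(-1)` (or `Y.reshape(-1, 1)`) before building the `TensorDataset` restores elementwise residuals.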

PyTorch itself does emit a warning about the shape mismatch, but it is easy to overlook when a lot of other output is produced. An error message from the Laplace library would have saved me a lot of time.
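Until the library raises an error itself, a user-side guard is one option. The sketch below is a minimal example under stated assumptions: `check_regression_targets` is a hypothetical helper (not part of the `laplace` package) that rejects 1-D target batches before `la.fit` is called. The second part assumes the warning in question is the "Using a target size ..." `UserWarning` that `torch.nn.functional.mse_loss` emits for mismatched shapes, and promotes it to an exception with the standard `warnings` module.

```python
import warnings
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def check_regression_targets(dataloader):
    """Hypothetical pre-flight check (not part of the laplace package):
    raise if any target batch lacks a target dimension."""
    for _, y in dataloader:
        if y.dim() < 2:
            raise ValueError(
                f"Regression targets have shape {tuple(y.shape)}; expected "
                "(batch_size, output_dim). Try y.unsqueeze(-1)."
            )

N, in_d = 10, 5
X = torch.randn(N, in_d)
bad = DataLoader(TensorDataset(X, torch.randn(N)), batch_size=5)      # targets (5,)
good = DataLoader(TensorDataset(X, torch.randn(N, 1)), batch_size=5)  # targets (5, 1)

check_regression_targets(good)  # passes silently
try:
    check_regression_targets(bad)
    bad_rejected = False
except ValueError as err:
    bad_rejected = True
    print(err)

# Alternative: turn PyTorch's own shape-mismatch warning into an error.
with warnings.catch_warnings():
    warnings.filterwarnings("error", message="Using a target size")
    try:
        F.mse_loss(torch.zeros(5, 1), torch.zeros(5))
        promoted = False
    except UserWarning:
        promoted = True
print(promoted)  # True
```

The guard is cheap but iterates over the whole dataloader once; the warning filter costs nothing extra but only helps if the mismatched shapes actually reach `F.mse_loss`.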

wiseodd commented 2 months ago

Thanks a lot for reporting this!