I noticed that for one-dimensional targets the library produces erroneous results if the targets are missing a target dimension (i.e. they only have a batch dimension). Here is a quick example with a linear model for which the Hessian is straightforward to compute:
import torch
from torch.utils.data import TensorDataset, DataLoader
import laplace

N = 10    # number of datapoints
in_d = 5  # input dimension

# generate dataset; note that Y has no target dimension
X = torch.randn((N, in_d))
Y = torch.randn((N,))
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=5)

# linear net without bias, so the Hessian is easy to write down
model = torch.nn.Linear(in_d, 1, bias=False)

# Laplace approximation
la = laplace.FullLaplace(model=model, likelihood='regression')
la.fit(dataloader)

# compare la.H with the theoretical Hessian X^T X
H_theory = torch.einsum('ij,ik->jk', X, X)
print(torch.all(torch.isclose(la.H, H_theory)))  # prints False
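For reference, with the Gaussian likelihood and unit observation noise (sigma_noise=1, which I believe is the default), the negative log-likelihood of the bias-free linear model is an ordinary least-squares objective, so its Hessian in the weights is just the Gram matrix of the inputs:

$$
H = \nabla_w^2 \, \frac{1}{2} \sum_{i=1}^{N} \left( w^\top x_i - y_i \right)^2 = \sum_{i=1}^{N} x_i x_i^\top = X^\top X,
$$

which is exactly what the einsum above computes.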
Once you modify Y = torch.randn((N,)) to Y = torch.randn((N, 1)), the correct Hessian is returned.
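For an already existing 1-d target tensor, the equivalent fix is a single unsqueeze:

Y = Y.unsqueeze(-1)  # shape (N,) -> (N, 1)

With that change, la.H matches H_theory and the print above yields True.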
While PyTorch itself does emit a warning here (the MSE loss warns that the target size differs from the model output size and will be broadcast, which is presumably what corrupts the Hessian), it's easy to overlook when lots of other output is produced. An error message from the Laplace library would have saved me a lot of time.
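A simple shape check at fitting time would catch this early. Here is a minimal sketch of the kind of guard I mean (the helper name and its call site are hypothetical, not the library's actual internals):

import torch

def _check_target_shape(f: torch.Tensor, y: torch.Tensor) -> None:
    # hypothetical helper -- a sketch of the proposed check, not laplace's API
    # f: model output of shape (batch, n_outputs); y: regression targets
    if y.ndim != f.ndim:
        raise ValueError(
            f'Target shape {tuple(y.shape)} does not match model output shape '
            f'{tuple(f.shape)}. For 1-d regression targets, pass targets of '
            f'shape (batch, 1), e.g. via y.unsqueeze(-1).'
        )

Calling something like this on each batch inside fit() would turn the silent broadcast into a hard error.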