aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Library breaks double precision #246

Open joemrt opened 1 month ago

joemrt commented 1 month ago

I noticed that FullLaplace produces single-precision (float32) results even when only double-precision objects are used. Here is a quick example:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import laplace

# Build a double-precision dataset and model
dtype = torch.float64
X = torch.randn((100, 3), dtype=dtype)
Y = torch.randn((100, 3), dtype=dtype)
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=10)
model = nn.Linear(3, 3, dtype=dtype)

# Fit a full-Hessian Laplace approximation over all weights
full_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='full')
full_la.fit(dataloader)
print(full_la.H.dtype)  # prints torch.float32 (at least on my machine)

Interestingly, when using KFAC instead of the full Hessian (i.e., when 'full' above is replaced by 'kron'), the Hessian la.H.to_matrix() is of dtype torch.float64.
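For completeness, the kron variant mentioned above can be reproduced with the following sketch, reusing model and dataloader from the example (kron_la is just an illustrative variable name):

# Same setup, but with a Kronecker-factored (KFAC) Hessian approximation
kron_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='kron')
kron_la.fit(dataloader)
print(kron_la.H.to_matrix().dtype)  # torch.float64: double precision is preserved here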

wiseodd commented 1 month ago

Thanks @joemrt! I can confirm via this test https://github.com/aleximmer/Laplace/commit/c182504070e3ae47e3ba374a009846f9b5c12f46 that this is indeed unintended behavior.
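A minimal sketch of what such a dtype regression test could look like; this is hypothetical and not the actual committed test, and the test name and structure are assumptions:

def test_full_laplace_preserves_float64():
    # Hypothetical reproduction of the issue as a test case
    dtype = torch.float64
    X = torch.randn((100, 3), dtype=dtype)
    Y = torch.randn((100, 3), dtype=dtype)
    dataloader = DataLoader(TensorDataset(X, Y), batch_size=10)
    model = nn.Linear(3, 3, dtype=dtype)
    la = laplace.Laplace(model=model, subset_of_weights='all',
                         likelihood='regression', hessian_structure='full')
    la.fit(dataloader)
    assert la.H.dtype == torch.float64  # currently fails: H is float32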

joemrt commented 1 month ago

Great, thanks @wiseodd! At first glance, it appears that the backend is not passed to Laplace in the test.
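For reference, passing a backend explicitly would look roughly like the following; BackPackGGN is used purely as an example of one of the library's curvature backends, and may not be the one the test intends:

from laplace.curvature import BackPackGGN

# Explicitly select a curvature backend when constructing the approximation
full_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='full',
                          backend=BackPackGGN)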