aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
432 stars, 61 forks

test_full_hess_curvlinops_vs_asdl fails for the "regression" loss in Python 3.11 #154

Closed: Yuxin-99 closed this issue 3 months ago

Yuxin-99 commented 3 months ago

Hi!

I was trying to install Laplace in a Python 3.11 environment with PyTorch 2.2.1. I cloned the main branch and built it from source. All the tests passed except test_full_hess_curvlinops_vs_asdl[regression]. The detailed failure output is below. Do I need any additional setup steps to make Laplace work properly under Python 3.11? Thank you in advance!

py::test_full_hess_curvlinops_vs_asdl[regression] failed: class_Xy = (tensor([[ 0.3615,  0.9511, -0.2172],
        [ 0.7430, -0.8946,  0.0683],
        [-0.5409, -0.4365,  1.4726],
      ...9],
        [ 0.3897,  1.7161,  0.3971],
        [-1.4383, -0.0673, -1.7935]]), tensor([1, 0, 0, 0, 0, 0, 0, 1, 1, 0]))
reg_Xy = (tensor([[ 0.3615,  0.9511, -0.2172],
        [ 0.7430, -0.8946,  0.0683],
        [-0.5409, -0.4365,  1.4726],
      ...805],
        [-0.7166, -0.0973],
        [ 0.2902, -0.9757],
        [ 1.5647, -2.0214],
        [ 0.3714,  0.3550]]))
model = Sequential(
  (0): Linear(in_features=3, out_features=20, bias=True)
  (1): Tanh()
  (2): Linear(in_features=20, out_features=2, bias=True)
)
loss_type = 'regression'

>   ???
E   assert False
E    +  where False = <built-in method allclose of type object at 0x7f42e16dd8a0>(tensor([[ 0.0328,  0.0850,  0.7302,  ...,  0.0478, -0.0848,  0.1754],\n        [ 0.0850, -0.5154,  0.1735,  ...,  0.626...2687, -0.1320,  ...,  0.0000, 10.0000,  0.0000],\n        [ 0.1754,  0.5561,  0.2731,  ...,  5.0403,  0.0000, 10.0000]]), tensor([[ 0.0328,  0.0850,  0.7302,  ...,  0.0478, -0.0848,  0.1754],\n        [ 0.0850, -0.5154,  0.1735,  ...,  0.626...2687, -0.1320,  ...,  0.0000, 10.0000,  0.0000],\n        [ 0.1754,  0.5561,  0.2731,  ...,  5.0403,  0.0000, 10.0000]]), rtol=5e-05)
E    +    where <built-in method allclose of type object at 0x7f42e16dd8a0> = torch.allclose
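
In the output above, the two Hessians (going by the test name, presumably one from the curvlinops backend and one from the ASDL backend) print identically. So whatever differences made torch.allclose(..., rtol=5e-05) fail sit either below the four printed decimals or in the entries the printout elides. One way to quantify the gap is the minimal pure-Python sketch below. Given the two matrices flattened to lists, it computes the smallest rtol that would satisfy the elementwise rule documented for torch.allclose, |input - other| <= atol + rtol * |other|, where atol defaults to 1e-08. The helper name and the sample values are hypothetical and not part of Laplace or its test suite.

```python
import math


def min_rtol_to_pass(a, b, atol=1e-8):
    """Smallest rtol such that |a_i - b_i| <= atol + rtol * |b_i| for every i.

    Hypothetical helper mirroring the elementwise rule documented for
    torch.allclose(input=a, other=b). Returns math.inf when some entry can
    never pass by raising rtol alone (b_i == 0 and the gap exceeds atol).
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same number of entries")
    worst = 0.0
    for x, y in zip(a, b):
        excess = abs(x - y) - atol
        if excess <= 0:
            continue  # already covered by the absolute tolerance
        if y == 0:
            return math.inf  # rtol * |b_i| is 0, so no rtol can help
        worst = max(worst, excess / abs(y))
    return worst


# Hypothetical flattened Hessian entries from two backends.
h_a = [0.0328, 0.5561, 10.0]
h_b = [0.0328, 0.5562, 10.0]
needed = min_rtol_to_pass(h_a, h_b)
print(f"smallest passing rtol: {needed:.2e}")  # about 1.8e-04: above 5e-05, below 1e-03
```

With real tensors one could pass H1.flatten().tolist() and H2.flatten().tolist(). A result of math.inf means some reference entry is exactly zero while the difference exceeds atol, which only a larger atol could fix.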

runame commented 3 months ago

Hi, this could just be a numerical precision issue. Can you increase rtol, e.g. to 1e-3, and rerun the tests?
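
Raising rtol helps because torch.allclose accepts two tensors only when every element satisfies |input - other| <= atol + rtol * |other|. The rtol term therefore scales the allowed gap with the magnitude of each entry. Two backends that compute the same Hessian through different sequences of floating-point operations can legitimately disagree in the low-order digits. The sketch below is a hypothetical pure-Python stand-in for that check, not Laplace's test code, run on made-up values to show how moving rtol from 5e-05 to 1e-03 flips the result:

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Pure-Python stand-in for torch.allclose on flat lists of floats.

    Uses the documented rule |a_i - b_i| <= atol + rtol * |b_i| with
    torch.allclose's default rtol/atol; NaN/inf handling is omitted.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same number of entries")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# Hypothetical entries whose relative gap (~1.8e-4) lies between the two rtols.
h_a = [0.0328, 0.5561, 10.0]
h_b = [0.0328, 0.5562, 10.0]
print(allclose(h_a, h_b, rtol=5e-5))  # False: the tolerance the test used
print(allclose(h_a, h_b, rtol=1e-3))  # True: the tolerance suggested above
```

Note that rtol gives no slack for entries whose reference value is exactly zero; only atol applies there.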

Yuxin-99 commented 3 months ago

That solved it. Thanks a lot for your help! :)