aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

RunningNLLMetric returns infinite when running optimize_prior_precision #194

Closed: MJordahn closed this issue 3 months ago.

MJordahn commented 3 months ago

Discussion started in #193.

After running fit (using the AsdlGGN backend) on a FullLLLaplace, I can make predictions with the Laplace object and see that the predictive function changes relative to my MAP model, but in a reasonable way. When I then run optimize_prior_precision with the gridsearch method, the model's predictive function breaks.

When I enable the progress_bar option (available in the GitHub version of the repo), I get the following output (just a snippet):

  0%|          | 0/21 [00:00<?, ?it/s]
[Grid search | prior_prec: 1.000e-04, loss: inf]:   0%|          | 0/21 [00:22<?, ?it/s]
[Grid search | prior_prec: 1.000e-04, loss: inf]:   5%|▍         | 1/21 [00:22<07:39, 22.97s/it]
[Grid search | prior_prec: 2.512e-04, loss: inf]:   5%|▍         | 1/21 [00:44<07:39, 22.97s/it]
[Grid search | prior_prec: 2.512e-04, loss: inf]:  10%|▉         | 2/21 [00:44<07:00, 22.12s/it]
[Grid search | prior_prec: 6.310e-04, loss: inf]:  10%|▉         | 2/21 [01:07<07:00, 22.12s/it]
[Grid search | prior_prec: 6.310e-04, loss: inf]:  14%|█▍        | 3/21 [01:07<06:43, 22.43s/it]
[Grid search | prior_prec: 1.585e-03, loss: inf]:  14%|█▍        | 3/21 [01:30<06:43, 22.43s/it]
[Grid search | prior_prec: 1.585e-03, loss: inf]:  19%|█▉        | 4/21 [01:30<06:23, 22.57s/it]
[Grid search | prior_prec: 3.981e-03, loss: inf]:  19%|█▉        | 4/21 [01:52<06:23, 22.57s/it]
[Grid search | prior_prec: 3.981e-03, loss: inf]:  24%|██▍       | 5/21 [01:52<06:02, 22.64s/it]
[Grid search | prior_prec: 1.000e-02, loss: inf]:  24%|██▍       | 5/21 [02:15<06:02, 22.64s/it]
[Grid search | prior_prec: 1.000e-02, loss: inf]:  29%|██▊       | 6/21 [02:15<05:40, 22.70s/it]
[Grid search | prior_prec: 2.512e-02, loss: inf]:  29%|██▊       | 6/21 [02:38<05:40, 22.70s/it]
[Grid search | prior_prec: 2.512e-02, loss: inf]:  33%|███▎      | 7/21 [02:38<05:18, 22.73s/it]
[Grid search | prior_prec: 6.310e-02, loss: inf]:  33%|███▎      | 7/21 [03:01<05:18, 22.73s/it]
[Grid search | prior_prec: 6.310e-02, loss: inf]:  38%|███▊      | 8/21 [03:01<04:55, 22.75s/it]
[Grid search | prior_prec: 1.585e-01, loss: inf]:  38%|███▊      | 8/21 [03:24<04:55, 22.75s/it]
[Grid search | prior_prec: 1.585e-01, loss: inf]:  43%|████▎     | 9/21 [03:24<04:33, 22.77s/it]
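The prior precisions in the log are evenly spaced on a log scale: each value is the previous one times 10^0.4. A minimal sketch that reproduces them, assuming the 21-point grid runs from 1e-4 to 1e4 (only the first nine values are shown above, so the upper end is inferred from the spacing, not taken from the library):

```python
# Reconstruct a log-spaced grid of prior precisions.
# Assumption: 21 points with log10 values from -4 to 4 (step 0.4),
# which matches the first nine values printed in the progress bar.
def log_grid(log_min: float, log_max: float, n: int) -> list[float]:
    step = (log_max - log_min) / (n - 1)
    return [10 ** (log_min + k * step) for k in range(n)]


grid = log_grid(-4.0, 4.0, 21)
for prior_prec in grid[:9]:
    print(f"{prior_prec:.3e}")
```

The printed values match the prior_prec column above, so the grid itself is fine; every candidate simply ends up with an infinite loss.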

When I compute the NLL with PyTorch's NLLLoss on the model that has only been fitted (without optimizing the prior precision), the values are reasonable.
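For reference, the NLL being compared is the mean negative log-probability assigned to the true labels. PyTorch's NLLLoss expects log-probabilities as input, so predictive probabilities must be log-transformed first. A minimal stdlib sketch of the same quantity; the probabilities below are made up for illustration and are not from the issue:

```python
import math


def mean_nll(probs: list[list[float]], targets: list[int]) -> float:
    """Mean negative log-likelihood of the true classes.

    Matches torch.nn.NLLLoss applied to log(probs) with the default
    reduction='mean' and no class weights.
    """
    total = 0.0
    for p, y in zip(probs, targets):
        total -= math.log(p[y])
    return total / len(targets)


# Illustrative predictive probabilities: two samples, three classes.
probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
targets = [0, 1]
print(mean_nll(probs, targets))
```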

I have just seen issue #157, which I suspect is related to this problem.

MJordahn commented 3 months ago

I have found a solution for my case. The problem stems from the fact that the loss passed in is never moved to the device the rest of the Laplace object is on. It failed silently because of the try statement, as mentioned in #157.
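The failure mode, where a broad try/except turns any error during evaluation into an infinite loss, can be reproduced without the library. A minimal sketch: `evaluate` and `grid_search` are hypothetical stand-ins, not the Laplace code, and the device mismatch is simulated by raising a RuntimeError like the one PyTorch raises when tensors sit on different devices:

```python
import math


def evaluate(prior_prec: float) -> float:
    # Hypothetical stand-in for computing the validation loss. It simulates
    # a metric left on the wrong device, which in PyTorch surfaces as a
    # RuntimeError on the first tensor operation.
    raise RuntimeError("Expected all tensors to be on the same device")


def grid_search(grid: list[float]) -> dict[float, float]:
    losses = {}
    for prior_prec in grid:
        try:
            losses[prior_prec] = evaluate(prior_prec)
        except Exception:
            # Any error is swallowed and recorded as an infinite loss,
            # so the real cause never reaches the user.
            losses[prior_prec] = math.inf
    return losses


print(grid_search([1e-4, 2.512e-4, 6.310e-4]))
```

Every candidate gets `inf`, exactly the pattern in the progress bar, while the RuntimeError that explains it is never shown.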

My solution was simply to move the loss metric to the Laplace object's own device (_device) inside optimize_prior_precision_base of BaseLaplace: if loss is None, I instantiate RunningNLLMetric directly on _device; otherwise, I move the given loss there.
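The fix amounts to making sure the metric lives on the model's device before it is used. The patched code is not shown in the thread, so this is only a library-free sketch of the pattern: `Metric`, `resolve_metric`, and the device strings are hypothetical and merely mimic the `.to(device)` convention of PyTorch modules:

```python
from typing import Optional


class Metric:
    """Hypothetical metric that records which device its state lives on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "Metric":
        # Mirrors the PyTorch convention: move state in place, return self.
        self.device = device
        return self


def resolve_metric(loss: Optional[Metric], device: str) -> Metric:
    # No metric passed: create the default one directly on `device`.
    if loss is None:
        return Metric(device=device)
    # Metric passed: move it to `device` before it is used.
    return loss.to(device)


print(resolve_metric(None, "cuda:0").device)
print(resolve_metric(Metric(), "cuda:0").device)
```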

I am not sure it is a pretty solution, but it works for me now (at least I no longer get inf values when running optimize_prior_precision).

wiseodd commented 3 months ago

Thanks! Yes, that try statement is indeed insidious. Maybe letting it fail (i.e., removing the try-except) is a better design. @runame, thoughts?

runame commented 3 months ago

Yeah, removing the try-except statement sounds like the way to go. Alternatively, we could restrict which errors are caught, so that RuntimeErrors are no longer swallowed.
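The second option can be sketched as follows. It assumes, purely for illustration, that the only error worth tolerating is a hypothetical `NumericalError` raised for a candidate that cannot be evaluated; anything else, including a device-mismatch RuntimeError, propagates to the caller. Removing the try-except entirely would make every error propagate.

```python
import math
from typing import Callable


class NumericalError(ValueError):
    """Hypothetical error for a prior precision that cannot be evaluated."""


def grid_search(grid: list[float], evaluate: Callable[[float], float]) -> float:
    """Return the prior precision with the lowest loss."""
    losses = {}
    for prior_prec in grid:
        try:
            losses[prior_prec] = evaluate(prior_prec)
        except NumericalError:
            # Only the expected, recoverable failure is recorded as inf;
            # a RuntimeError (e.g. a device mismatch) propagates instead.
            losses[prior_prec] = math.inf
    return min(losses, key=losses.__getitem__)


def toy_loss(prior_prec: float) -> float:
    # Illustrative loss, minimized at prior_prec == 1.0.
    if prior_prec < 1e-3:
        raise NumericalError("prior precision too small")
    return math.log10(prior_prec) ** 2


def broken_loss(prior_prec: float) -> float:
    raise RuntimeError("Expected all tensors to be on the same device")


print(grid_search([1e-4, 1e-2, 1.0, 1e2], toy_loss))  # prints 1.0

try:
    grid_search([1e-4, 1.0], broken_loss)
except RuntimeError as err:
    print("surfaced:", err)
```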