MJordahn closed this issue 3 months ago.
I have found a solution for my case. The problem stems from the fact that the passed `loss` is never moved to the same device as the rest of the Laplace class. It was failing silently because of the `try` statement mentioned in #157.
My solution was simply to move the loss-function object to the Laplace module's own device (`_device`) inside `optimize_prior_precision_base` of `BaseLaplace`: if `loss` is `None`, I instantiate `RunningNLLMetric` directly on `_device`; otherwise, I move it there.
I am not sure it is a pretty solution, but it works for me now (at least I no longer get `inf` values when running `optimize_prior_precision`).
Thanks! Yes, that `try` statement is indeed insidious. Maybe letting it fail (i.e. removing the `try`-`except`) is a better design. @runame, thoughts?
Yeah, removing the `try`-`except` statement sounds like the way to go. Alternatively, we could restrict which errors are caught so that no `RuntimeError` is swallowed.
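To make the `try`-`except` trade-off concrete, here is a plain-Python sketch of a grid search over prior precisions. Passing `tolerated=(Exception,)` mimics a broad `except`, and the default mimics a restricted one. `gridsearch`, `NumericalError`, and `broken_score` are hypothetical names; this is not laplace-torch's implementation, and which exceptions the library should tolerate is exactly the open question above.

```python
import math
from typing import Callable, Sequence, Tuple, Type


class NumericalError(Exception):
    """Hypothetical error type for failures worth tolerating during the search."""


def gridsearch(
    candidates: Sequence[float],
    score_fn: Callable[[float], float],
    tolerated: Tuple[Type[BaseException], ...] = (NumericalError,),
) -> Tuple[float, float]:
    """Return (best_candidate, best_score), where a lower score is better.

    Only exceptions listed in `tolerated` are converted into an `inf` score;
    anything else (such as a RuntimeError from a device mismatch) propagates.
    """
    scores = []
    for c in candidates:
        try:
            s = score_fn(c)
        except tolerated:
            s = math.inf
        scores.append(s)
    best = min(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best], scores[best]


def broken_score(prior_prec: float) -> float:
    """Simulates the reported bug: every evaluation fails with a RuntimeError."""
    raise RuntimeError("Expected all tensors to be on the same device")
```

With the broad variant, `gridsearch([0.1, 1.0, 10.0], broken_score, tolerated=(Exception,))` returns `(0.1, inf)` without raising. Every candidate scores `inf`, so the first one is picked and the bug stays invisible. With the default `tolerated`, the `RuntimeError` surfaces immediately.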
Discussion started in #193.
After running `fit` (using the `AsdlGGN` backend) in `FullLLLaplace`, I can make predictions with the Laplace module (and I see that my predictive function changes compared to my MAP model, but in a reasonable way). When I then run `optimize_prior_precision` using the `gridsearch` method, the model's predictive function breaks.
When I use the `progress_bar` option in the GitHub version of the repo, I get the following (just a snippet):
When I compute the NLL using PyTorch's `NLLLoss` with the model that has only been fitted (i.e. before the prior precision is optimized), the outputs are reasonable.
I have just seen issue #157, which I suppose is related to this problem.
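As a reminder of what the NLL sanity check above computes, here is a minimal plain-Python sketch. `mean_nll` is a hypothetical helper that takes per-example class probabilities and integer targets. It shows that a predicted probability of exactly 0 for the true class yields an infinite NLL, which is one way `inf` values can appear once a predictive breaks. In PyTorch, note that `torch.nn.NLLLoss` expects log-probabilities, so probabilities returned by a predictive need a `log()` before being passed to it.

```python
import math
from typing import Sequence


def mean_nll(probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-likelihood of integer class targets.

    `probs[i][k]` is the predicted probability of class k for example i.
    A probability of exactly 0 for the true class contributes an infinite NLL.
    """
    total = 0.0
    for p, y in zip(probs, targets):
        py = p[y]
        total += -math.log(py) if py > 0 else math.inf
    return total / len(targets)
```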