aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Laplace predictive uncertainty remaining almost constant across samples #142

Closed nilsleh closed 4 months ago

nilsleh commented 5 months ago

I am working with an image regression dataset (image input, single scalar target), specifically the Tropical Cyclone dataset, and want to use Laplace as a UQ method. I have tried various settings (last-layer and subnetwork), but the predictive uncertainty returned by Laplace remains almost constant. For example, using the default parameters with last-layer Laplace, I get:

array([1.0037074, 1.0050274, 1.0049183, 1.0054605, 1.0041007, 1.0071181,
       1.0056504, 1.0059941, 1.0050477, 1.0070947, 1.0063754, 1.0042056,
       1.004853 , 1.0043836, 1.0043393, 1.005228 , 1.0033851, 1.003367 ,
       1.0051998, 1.0050663, 1.0061045, 1.0041615, 1.0053991, 1.0045779,
       1.0033392, 1.0052227, 1.0065176, 1.0044383, 1.0048069, 1.0052433,
       1.0072025, 1.0034157, 1.0053388, 1.0066772, 1.0046337, 1.0053458,
       1.0048453, 1.0034424, 1.0055923, 1.005143 , 1.0052983, 1.0067523,
       1.005628 , 1.0055794, 1.00653  , 1.0054258, 1.0059325, 1.0059363,
       1.0061395, 1.00672  , 1.0042725, 1.0038388, 1.0042018, 1.0081434])

so effectively almost the same uncertainty regardless of the input sample. I was wondering whether you had any pointers that could explain this behavior?

aleximmer commented 5 months ago

Do you have some details on the network you use and how exactly you apply Laplace to it? I am afraid there are not enough details to figure the issue out, but such uncertainties are definitely unusual.

nilsleh commented 5 months ago

Thanks for the quick reply! The framework I am using is an open-source project that implements different UQ methods and makes them available through Lightning, so they are more accessible and easier to compare. It is essentially a Lightning wrapper for all sorts of UQ methods, including your Laplace library (I meant to ask about some implementation details of that in a separate issue).

But to make sure that the results above didn't just stem from a mistake in that wrapper implementation, I also wrote up the experiment without that framework in this gist. The data would have to be downloaded from this HF repo and the pretrained checkpoint is available here. The mean estimates are quite good, just the predictive uncertainty is unexpected.

In terms of requirements to run that script, pip install torchgeo should be everything you need besides your laplace library. Maybe that helps, but if setting that up is too time-consuming, debugging suggestions alone would also be helpful. Thanks again.

aleximmer commented 5 months ago

The problem is in the uncertainty computation pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item() ** 2) at https://gist.github.com/nilsleh/b8a8cee4e67a3ef683bfe743be794278#file-laplace_example-py-L484. la.sigma_noise defaults to 1 when instantiated and should be set to something better. It is the observation noise of your targets, and 1 is the most conservative choice: since you seem to standardize the data, it means the model cannot reduce the observation noise at all. One easy way is to estimate it with a maximum likelihood estimate. In pseudo code, emp_sigma_noise = sqrt(mean(square(target - pred))).
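
In code, that could look roughly like the following (only a sketch; det_model and train_loader stand in for your own model and data, and I am assuming the loader yields (image, target) batches):

import torch

@torch.no_grad()
def estimate_sigma_noise(model, loader, device="cpu"):
    # MLE of the observation noise: RMSE of the point predictions on the training data
    model.eval()
    sq_err, n = 0.0, 0
    for X, y in loader:
        pred = model(X.to(device)).squeeze(-1)
        sq_err += ((pred - y.to(device).squeeze(-1)) ** 2).sum().item()
        n += y.numel()
    return (sq_err / n) ** 0.5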

Other than that, it would also make sense to either tune the prior precision using la.optimize_prior_precision() or cross-validate it with respect to your performance metric. Lastly, it is generally better to fit Laplace with a training loader that has data augmentation turned off. These two points further improve performance, but the point above is the reason for the weird uncertainties, so fixing that should already give you varying values.
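
Putting both together, the whole recipe could look something like this (again only a sketch; train_dataset is assumed to be your training set with augmentation transforms disabled, the batch size is arbitrary, and estimate_sigma_noise is the helper from the snippet above):

from torch.utils.data import DataLoader
from laplace import Laplace

fit_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)  # no augmentation

emp_sigma_noise = estimate_sigma_noise(det_model, fit_loader)
la = Laplace(det_model, likelihood="regression", sigma_noise=emp_sigma_noise)
la.fit(fit_loader)
la.optimize_prior_precision(method='marglik')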

nilsleh commented 5 months ago

I had tried optimizing sigma_noise as well, by running

# Jointly optimize prior precision and observation noise of the fitted Laplace
# approximation la by maximizing the marginal likelihood
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for i in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()

however, that still yields nearly constant predictive uncertainties:

f_sigma = array([0.05167731, 0.04192743, 0.04703024, 0.04773571, 0.04734406,
       0.04734465, 0.04470056, 0.05894397, 0.04164458, 0.04090737,
       0.04194929, 0.04587073, 0.04123243, 0.05245287, 0.04637391,
       0.03705715, 0.0392273 , 0.04476492, 0.05171853, 0.04406474,
       0.04473803, 0.0452574 , 0.03747206, 0.04991284], dtype=float32)

pred_std = array([0.50240684, 0.50149775, 0.5019501 , 0.50201666, 0.5019796 ,
       0.50192714, 0.5019063 , 0.5010613 , 0.50169104, 0.5022994 ,
       0.5015247 , 0.50180167, 0.5019909 , 0.50270176, 0.50156194,
       0.5011141 , 0.50127923, 0.50174296, 0.50241107, 0.50168097,
       0.5017405 , 0.50178707, 0.50114495, 0.5022284 ], dtype=float32)

The only augmentations the training dataloader uses are image normalization and resizing, so I think that part should be okay. Computing the channel statistics of a batch yields:

X.mean(dim=[0, 2,3])
tensor([0.0732, 0.0763, 0.0765], device='cuda:0')

and 

X.std(dim=[0, 2,3])
tensor([1.0353, 1.0357, 1.0343], device='cuda:0')

target stats of the batch are

mean = tensor(0.1563)
std = tensor(1.1659)

aleximmer commented 5 months ago

OK, the data loader looks good in that case. Regarding the predictive uncertainties: there is now much more relative difference between them than before, especially for f_sigma, which denotes the model uncertainty. It looks, though, like the sigma_noise parameter is way too large, since the joint uncertainty pred_std is 10x larger than f_sigma. I personally would not recommend using the marginal likelihood after training to get sigma_noise. It might be fine for the prior precision, but I wouldn't rely on it: it usually requires iterative optimization of the marginal likelihood during training so that the regularization can actually impact the model weights.

So I would first test setting sigma_noise using the maximum likelihood estimate mentioned above, and then tune the prior precision using the marginal likelihood, but only if you are sure the model has a good fit (neither over- nor underfit).

nilsleh commented 5 months ago

Okay, regarding sigma noise, I was mainly following the docs about fitting and optimizing Laplace, where I understood the sigma_noise optimization to happen after training?

I now did the following, which is how I understood your recommendation:

# Set sigma_noise to the empirical estimate, fit Laplace, then tune the prior precision
la = Laplace(det_model, likelihood="regression", sigma_noise=0.5024)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

The model has a decent fit, based on accuracy metrics. It seems that computing pred_std from f_sigma still yields very small relative differences, since squaring f_sigma makes those terms an order of magnitude smaller than sigma_noise**2 and thereby sort of "removes" the relative differences. So unless f_sigma is already reasonably large and varies across inputs, it gets "squished" in pred_std, at which point the predictive uncertainty is almost "uniform" regardless of the input.
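
As a quick numerical check (taking sigma_noise of roughly 0.5 and the smallest and largest f_sigma values from above):

import math

sigma_noise = 0.5
print(math.sqrt(0.037 ** 2 + sigma_noise ** 2))  # ~0.5014
print(math.sqrt(0.059 ** 2 + sigma_noise ** 2))  # ~0.5035
# f_sigma differs by ~60% between these two inputs, but pred_std by less than 0.5%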

One observation was that tuning the prior precision yields a fairly large value (la.prior_precision: 353.7385), but not tuning it gives similar f_sigma results as well.

aleximmer commented 5 months ago

Yes, in that example it works quite well because it's easy to visualize and we can be sure it's not overfitting. We have to improve the docs on this aspect though, thanks for pointing it out!

Just to understand what you are doing now: the sigma_noise=0.5024 is the empirical standard deviation of (predictions - targets) on the training data, right? Then this looks good to me. I think the prior precision value you get is relatively normal. It might also just be that your model is very certain about the mean function (hence the low epistemic f_sigma) while the targets are quite noisy (hence the high sigma_noise), but this depends on the data and model. Do you think this could be possible? It would happen if the dataset is very easy to fit, e.g., has a simple mean function, but relatively high measurement noise.

nilsleh commented 5 months ago

Yes, that is correct. I suppose that is quite possible. The label distribution is significantly positively skewed (most images are of cyclones with low wind speeds, with fewer images at higher wind speeds), so we have noticed that models usually fit a mean function such that errors on lower targets (the majority of samples) are quite low, while errors are higher in the upper target ranges. So there is definitely target noise, but we assume the main factor to be the imbalanced target distribution.

We have seen better performance on uncertainty metrics from methods like mean-variance networks, so potentially your recent work on Heteroskedastic Laplace could be useful?

aleximmer commented 5 months ago

Yes, I think that would be worth a try. We found both the heteroskedastic predictive and the regularization from optimizing the marginal likelihood during training to help quite a bit in that case. The latter does, however, mean that one cannot use a pretrained network. If you only want the predictive, you can still use a pretrained network.