cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Question] LOO-CV with CRPS #1326

Open jrojsel opened 3 years ago

jrojsel commented 3 years ago

Hi,

Is it true that each element of mu and sigma2 in leave_one_out_pseudo_likelihood.py represents what the mean and variance of the posterior distribution would have been if the corresponding data point had been left out?

If so, I am curious about your thoughts on modifying this mll so that, instead of returning the mll, it returns the summed CRPS computed from mu and sigma2. Following Gneiting & Raftery (2007), Eq. 21, we could replace line 63 onward with:

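```python
import math  # module-level imports used by the snippet
import torch

# Inside LeaveOneOutPseudoLikelihood.forward, after mu and sigma2
# (the LOO predictive means and variances) have been computed.
# Positively oriented Gaussian CRPS (Gneiting & Raftery, 2007, Eq. 21):
# higher is better, so it can be maximized just like the pseudo-likelihood.
sigma = sigma2.sqrt()
z = (target - mu) / sigma  # standardized LOO residuals
std_normal = torch.distributions.Normal(0.0, 1.0)
pdf_term = std_normal.log_prob(z).exp()  # phi(z)
cdf_term = std_normal.cdf(z)  # Phi(z)
crps = sigma * (1.0 / math.sqrt(math.pi) - 2 * pdf_term - z * (2 * cdf_term - 1))
res = crps.sum(dim=-1)  # sum over the held-out points
return res
```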

Any thoughts?
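As a sanity check on the sign convention, here is a small sketch (not part of the original post) that compares the closed form against a direct numerical evaluation of the CRPS definition, CRPS(F, x) = ∫ (F(y) − 1{y ≥ x})² dy. Gneiting & Raftery write scores in positive orientation (higher is better), so the closed form should equal the negative of that integral, which is what lets it stand in for a quantity to be maximized. The helper names `gaussian_crps_positive` and `crps_by_quadrature`, and the test values, are hypothetical.

```python
import math

import torch


def gaussian_crps_positive(mu, sigma2, target):
    # Same closed form as the snippet above (Gneiting & Raftery, 2007, Eq. 21).
    # Positively oriented: larger (closer to zero) means a better prediction.
    sigma = sigma2.sqrt()
    z = (target - mu) / sigma
    std_normal = torch.distributions.Normal(0.0, 1.0)
    pdf_term = std_normal.log_prob(z).exp()
    cdf_term = std_normal.cdf(z)
    return sigma * (1.0 / math.sqrt(math.pi) - 2 * pdf_term - z * (2 * cdf_term - 1))


def crps_by_quadrature(mu, sigma, x, half_width=12.0, n=200_001):
    # Usual (negatively oriented) CRPS from its definition,
    # integral of (F(y) - 1{y >= x})^2 dy, via the trapezoidal rule.
    lo = min(mu, x) - half_width * sigma
    hi = max(mu, x) + half_width * sigma
    y = torch.linspace(lo, hi, n, dtype=torch.float64)
    F = torch.distributions.Normal(
        torch.tensor(mu, dtype=torch.float64), torch.tensor(sigma, dtype=torch.float64)
    ).cdf(y)
    step = (y >= x).to(torch.float64)
    return torch.trapezoid((F - step) ** 2, y).item()


mu, sigma2, x = 0.3, 1.7, 1.1
closed = gaussian_crps_positive(
    torch.tensor(mu, dtype=torch.float64),
    torch.tensor(sigma2, dtype=torch.float64),
    torch.tensor(x, dtype=torch.float64),
).item()
numeric = crps_by_quadrature(mu, math.sqrt(sigma2), x)
print(closed, numeric)  # approximately equal magnitude, opposite sign
```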

jacobrgardner commented 3 years ago

So, the current implementation should stay as it is, because it implements the idea as it appears in Rasmussen & Williams (R&W).

If you want to open a PR adding a separate loss using CRPS, we could look at that though?

dme65 commented 3 years ago

> Is it true that each element of mu and sigma2 in leave_one_out_pseudo_likelihood.py represents what the mean and variance of the posterior distribution would have been if the corresponding data point had been left out?

In addition to Jake's answer: yes, element i of mu and sigma2 corresponds to the predictive mean and variance for sample i when that sample has been removed.
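For reference, the standard closed form behind this (see Rasmussen & Williams, 2006, Sec. 5.4.2) is μ_i = y_i − [K⁻¹y]_i / [K⁻¹]_ii and σ²_i = 1 / [K⁻¹]_ii, where K is the training covariance including the noise term. The sketch below checks it against brute-force conditioning on all points except i. It assumes a zero prior mean and fixed hyperparameters; the RBF kernel settings, the noise variance of 0.05, and the function names are hypothetical choices for illustration only.

```python
import torch


def rbf_kernel(x1, x2, lengthscale=0.7, outputscale=1.3):
    # Hypothetical fixed hyperparameters, chosen only for this illustration.
    sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return outputscale * torch.exp(-0.5 * sq_dist / lengthscale**2)


def loo_closed_form(K, y):
    # K: (n, n) training covariance *including* the noise term; y: (n,) targets.
    K_inv = torch.cholesky_inverse(torch.linalg.cholesky(K))
    sigma2 = 1.0 / K_inv.diagonal()  # sigma2_i = 1 / [K^{-1}]_ii
    mu = y - (K_inv @ y) * sigma2  # mu_i = y_i - [K^{-1} y]_i / [K^{-1}]_ii
    return mu, sigma2


def loo_brute_force(K, y):
    # For each i, condition on every other point and predict point i.
    n = y.shape[0]
    mus, sigma2s = [], []
    for i in range(n):
        keep = torch.arange(n) != i
        k_i = K[keep, i]
        w = torch.linalg.solve(K[keep][:, keep], k_i)
        mus.append(w @ y[keep])
        sigma2s.append(K[i, i] - w @ k_i)
    return torch.stack(mus), torch.stack(sigma2s)


torch.manual_seed(0)
X = torch.rand(12, 3, dtype=torch.float64)
y = torch.sin(X.sum(-1)) + 0.1 * torch.randn(12, dtype=torch.float64)
K = rbf_kernel(X, X) + 0.05 * torch.eye(12, dtype=torch.float64)  # 0.05: noise variance
mu_cf, sigma2_cf = loo_closed_form(K, y)
mu_bf, sigma2_bf = loo_brute_force(K, y)
```

Because K includes the noise, σ²_i here is the predictive variance of the held-out observation y_i, not of the latent function value.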

jrojsel commented 3 years ago

Ok thanks! I'll do some experiments, and if the CRPS thing turns out to be a good idea, I'll open a PR. Closing for now.

jrojsel commented 3 years ago

@dme65 @jacobrgardner OK, so I did actual leave-one-out cross-validation on a noisy, 30-dimensional dataset of 47 data points obtained from physical experiments: I trained a model 47 times (once per held-out point) with each of three objectives, namely the exact mll, the pseudo LOO-CV mll, and the LOO-CV CRPS "mll".

There were four outputs (observables), so I repeated the procedure for each observable. The pseudo LOO-CV mll performed slightly better than the exact mll, but the LOO-CV CRPS approach was much better: there is a clear difference between LOO-CV CRPS (solid lines) and pseudo LOO-CV mll (dashed lines) in the figure below. I did not plot the exact mll runs, because their results are indistinguishable from those of the pseudo LOO-CV mll runs. The blue, orange and green observables are very noisy, while the red one is expected to be less noisy.

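[Figure: LOO-CV errors for the four observables (blue, orange, green, red); solid lines: LOO-CV CRPS, dashed lines: pseudo LOO-CV mll]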

Based on these results, I think we should add the LOO-CV CRPS approach to GPyTorch, but I am not sure how best to do it. We could add a "crps" kwarg (defaulting to False) to the constructor of the pseudo LOO-CV mll class, and evaluate the lines above instead of the current ones in the forward call when self.crps is True. Another option would be to split the pseudo LOO-CV mll into an abstract base class from which separate pseudo LOO-CV mll and LOO-CV CRPS "mll" classes inherit the shared code.
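A rough sketch of the second option, written in plain PyTorch rather than against GPyTorch's actual mll classes (whose internals are not reproduced here): a hypothetical base class computes the LOO means and variances once from a covariance matrix, and subclasses supply only the per-sample score. The class names are illustrative, not existing GPyTorch classes, and the sketch leaves out everything else the real forward call handles, such as obtaining the covariance from the model's output distribution.

```python
import math
from abc import ABC, abstractmethod

import torch


class LOOObjectiveBase(ABC):
    """Hypothetical base class: shared LOO computation, per-sample score left abstract."""

    def loo_predictive(self, K, target):
        # K: (n, n) training covariance including noise; target: (n,)
        K_inv = torch.cholesky_inverse(torch.linalg.cholesky(K))
        sigma2 = 1.0 / K_inv.diagonal()
        mu = target - (K_inv @ target) * sigma2
        return mu, sigma2

    @abstractmethod
    def per_sample_score(self, mu, sigma2, target):
        """Return one score per held-out point (higher is better)."""

    def __call__(self, K, target):
        mu, sigma2 = self.loo_predictive(K, target)
        return self.per_sample_score(mu, sigma2, target).sum(dim=-1)


class LOOPseudoLikelihoodSketch(LOOObjectiveBase):
    # Gaussian log predictive density of each held-out point.
    def per_sample_score(self, mu, sigma2, target):
        return -0.5 * sigma2.log() - 0.5 * (target - mu) ** 2 / sigma2 - 0.5 * math.log(2 * math.pi)


class LOOCRPSSketch(LOOObjectiveBase):
    # Positively oriented Gaussian CRPS (Gneiting & Raftery, 2007, Eq. 21).
    def per_sample_score(self, mu, sigma2, target):
        sigma = sigma2.sqrt()
        z = (target - mu) / sigma
        std_normal = torch.distributions.Normal(0.0, 1.0)
        pdf_term = std_normal.log_prob(z).exp()
        cdf_term = std_normal.cdf(z)
        return sigma * (1.0 / math.sqrt(math.pi) - 2 * pdf_term - z * (2 * cdf_term - 1))


torch.manual_seed(0)
X = torch.rand(10, 2, dtype=torch.float64)
y = torch.sin(3 * X[:, 0]) + 0.1 * torch.randn(10, dtype=torch.float64)
K = torch.exp(-0.5 * torch.cdist(X, X) ** 2 / 0.5**2) + 0.01 * torch.eye(10, dtype=torch.float64)
print(LOOPseudoLikelihoodSketch()(K, y), LOOCRPSSketch()(K, y))
```

Compared with a boolean kwarg, this keeps each objective's formula in one place and makes adding a third score a matter of one more subclass.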

Balandat commented 3 years ago

Interesting. If we want to allow for general scoring rules in the LOO setting, we could also spec out a basic API for computing the loss per left-out sample (in a batched fashion), so that a scoring rule can be passed in easily. We can default this to the llh (log-likelihood).
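One way such an API could look, as a hypothetical sketch rather than an existing GPyTorch interface: the LOO objective accepts any callable that maps batched `(mu, sigma2, target)` tensors to per-sample scores, with the Gaussian log predictive density as the default. All function names here are illustrative.

```python
import math
from typing import Callable

import torch

ScoreFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def gaussian_log_score(mu, sigma2, target):
    # Default: log N(target | mu, sigma2), one value per held-out sample.
    return torch.distributions.Normal(mu, sigma2.sqrt()).log_prob(target)


def gaussian_crps_score(mu, sigma2, target):
    # Positively oriented Gaussian CRPS (Gneiting & Raftery, 2007, Eq. 21).
    sigma = sigma2.sqrt()
    z = (target - mu) / sigma
    std_normal = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
    pdf_term = std_normal.log_prob(z).exp()
    cdf_term = std_normal.cdf(z)
    return sigma * (1.0 / math.sqrt(math.pi) - 2 * pdf_term - z * (2 * cdf_term - 1))


def loo_objective(mu, sigma2, target, score_fn: ScoreFn = gaussian_log_score):
    # mu, sigma2, target: (..., n). Scores are summed over the last (data) dim,
    # so leading batch dimensions (e.g. one per output) are preserved.
    return score_fn(mu, sigma2, target).sum(dim=-1)


mu = torch.tensor([[0.1, -0.2, 0.4], [1.0, 0.9, 1.2]])
sigma2 = torch.tensor([[0.5, 0.4, 0.6], [0.2, 0.3, 0.25]])
target = torch.tensor([[0.0, 0.1, 0.3], [1.1, 0.7, 1.0]])
print(loo_objective(mu, sigma2, target))  # default log score, one value per batch row
print(loo_objective(mu, sigma2, target, score_fn=gaussian_crps_score))
```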

In the figures, it looks like while the CRPS-based approach does compress the errors towards zero, there are also some larger outliers compared to the pseudo-likelihood LOOCV (e.g. @18 for blue, and @-12 for green). Is there an intuitive interpretation for this?

jrojsel commented 3 years ago

> Interesting. If we want to allow for general scoring rules in the LOO setting, we could also spec out a basic API for computing the loss per left-out sample (in a batched fashion), so that a scoring rule can be passed in easily. We can default this to the llh (log-likelihood).

Sounds good!

> In the figures, it looks like while the CRPS-based approach does compress the errors towards zero, there are also some larger outliers compared to the pseudo-likelihood LOOCV (e.g. @18 for blue, and @-12 for green). Is there an intuitive interpretation for this?

Yes, there are actually two obvious outliers in the dataset. One experiment produced observable values that differed substantially from the rest, so it is no surprise that the model(s) had a hard time predicting them; this experiment corresponds to the most extreme bumps in each plot. I used a linear mean, and a small difference in the fitted slope could have caused the discrepancy between the LOO-CV pseudo-likelihood and LOO-CV CRPS results (I should rerun the entire thing multiple times just to be sure). The other outlier, @-10 for red, is just a noisy experiment.
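One property of the two scores that may be relevant to the outlier question (an observation added here, not an explanation given in the thread): for a fixed predictive standard deviation σ, the negative Gaussian log score grows quadratically in the standardized residual z, whereas the usual (negatively oriented) CRPS approaches σ(|z| − 1/√π) and so grows only linearly. A CRPS-trained model therefore pays comparatively less for badly missing a single extreme point, which is at least consistent with tighter errors on typical points and larger misses on the outlying experiment; whether that explains these particular plots is untested. The sketch below, with σ = 1 and arbitrary z values, tabulates both penalties.

```python
import math

import torch


def neg_log_score(z, sigma=1.0):
    # -log N(mu + z * sigma | mu, sigma^2): quadratic in z
    return 0.5 * z**2 + math.log(sigma) + 0.5 * math.log(2 * math.pi)


def crps(z, sigma=1.0):
    # Usual (negatively oriented) Gaussian CRPS; tends to sigma * (|z| - 1/sqrt(pi))
    std_normal = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
    pdf = std_normal.log_prob(z).exp()
    cdf = std_normal.cdf(z)
    return sigma * (z * (2 * cdf - 1) + 2 * pdf - 1.0 / math.sqrt(math.pi))


z = torch.tensor([0.0, 1.0, 2.0, 4.0, 8.0], dtype=torch.float64)
for zi, nl, cr in zip(z.tolist(), neg_log_score(z).tolist(), crps(z).tolist()):
    print(f"z={zi:4.1f}  -log score={nl:7.3f}  CRPS={cr:6.3f}")
```

Going from z = 4 to z = 8 raises the log-score penalty by 24 but the CRPS by only about 4.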