STOR-i / GaussianProcesses.jl

A Julia package for Gaussian Processes
https://stor-i.github.io/GaussianProcesses.jl/latest/

Worse performance after `optimize!(gpe)` #225

Open JanRotti opened 1 year ago

JanRotti commented 1 year ago

Dear all,

This is somewhat related to Issue #221. I was experimenting with hyperparameter optimization using optimize!(gp), and I get a lot of LinearAlgebra.PosDefException errors. When I simply rerun optimize!(gp), the errors disappear at some point and the optimization converges. But the model then performs considerably worse on test data than before. After optimization, the mll/target values are actually reduced by the optimization.

I would assume that if the error occurs, it should occur consistently and not disappear on rerunning. Is this behavior expected?

Here is a minimal example to reproduce:

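using GaussianProcesses
# Training data: one column per point, target is the sum of the two coordinates
X = [0.4  1.0  0.6  0.8  0.2; 1.0  0.4  0.2  0.8  0.6]
y = map(sum, eachcol(X))
# Test data built the same way
Xₜₑₛₜ = [0.229267  0.939358  0.538846  0.180355  0.545399; 0.701589  0.149632  0.293056  0.170041  0.423608]
yₜₑₛₜ = map(sum, eachcol(Xₜₑₛₜ))
# Zero mean, squared-exponential kernel with both log parameters at 0.0, log noise -6.0
gp = GPE(X, y, MeanZero(), SE(0.0, 0.0), -6.0)
# RMSE on the test set before optimization
@info "Before Opt: $(sqrt(sum((GaussianProcesses.predict(gp, Xₜₑₛₜ)[1] .- yₜₑₛₜ).^2)/length(yₜₑₛₜ)))"
# Retry optimize! until it returns without throwing (no upper bound on retries)
function optim()
    try
        optimize!(gp)
    catch
        optim()
    end
end
optim()
# RMSE on the test set after optimization
@info "After HyperOpt: $(sqrt(sum((GaussianProcesses.predict(gp, Xₜₑₛₜ)[1] .- yₜₑₛₜ).^2)/length(yₜₑₛₜ)))"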

Greetings, Jan
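
The retry in the example above recurses without a bound and discards every caught exception, so it is hard to see how many attempts failed or why. A minimal sketch of a bounded retry that keeps the errors is shown here. The helper name try_repeatedly and the stand-in function flaky are hypothetical and not part of GaussianProcesses.jl; for the example above the call would presumably be try_repeatedly(() -> optimize!(gp)).

```julia
# Hypothetical helper (not part of GaussianProcesses.jl): call `f` until it
# returns without throwing, giving up after `max_attempts` failures.
function try_repeatedly(f; max_attempts::Int = 20)
    errors = Exception[]
    for attempt in 1:max_attempts
        try
            # On success, report the result, the attempt count, and all errors seen.
            return (result = f(), attempts = attempt, errors = errors)
        catch err
            push!(errors, err)   # keep the error instead of discarding it
        end
    end
    error("no success after $max_attempts attempts")
end

# Stand-in for `optimize!(gp)`: throws on the first two calls, then succeeds.
calls = Ref(0)
flaky() = (calls[] += 1) < 3 ? throw(DomainError(calls[], "simulated failure")) : "converged"

out = try_repeatedly(flaky)
println(out.attempts, " attempts, ", length(out.errors), " errors")
```

Inspecting out.errors after the call shows whether every failed attempt raised the same exception type.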

maximerischard commented 1 year ago

One thing that is noticeable is that your test example is deterministic: y is just the sum of x_1 and x_2, with no added noise. This could explain why the optimizer runs into covariance matrices that are not positive definite, and why you run into trouble with predictions. Two things to try (separately):

JanRotti commented 1 year ago

You were totally right! Having noise=false does solve the problem. Still, I am wondering about the earlier behavior of optimize!. Shouldn't optimize!(gp) at least be consistent in its result? (That is, it should not throw a PosDefException $n-1$ times and then work on the $n$-th run.)
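
The link drawn earlier in the thread between a small noise term and covariance matrices that are not positive definite can be illustrated without the package. The sketch below builds a squared-exponential kernel matrix by hand on the training inputs from the example and prints the condition number of $K + \sigma_n^2 I$ for several noise levels. It assumes that the log noise value is the logarithm of the noise standard deviation, so that $\sigma_n^2 = \exp(2 \cdot \text{log noise})$; the function names se_kernel and cond_with_noise are made up for this illustration. A large condition number is the regime in which a Cholesky factorization can fail numerically.

```julia
using LinearAlgebra

# Training inputs from the example above, one column per point.
X = [0.4 1.0 0.6 0.8 0.2; 1.0 0.4 0.2 0.8 0.6]

# Squared-exponential kernel matrix, written out by hand for illustration.
# `ℓ` is the length scale and `σ` the signal standard deviation.
function se_kernel(X; ℓ = 1.0, σ = 1.0)
    n = size(X, 2)
    return [σ^2 * exp(-sum(abs2, X[:, i] .- X[:, j]) / (2 * ℓ^2)) for i in 1:n, j in 1:n]
end

K = se_kernel(X)

# Condition number of K + σₙ² I, assuming `log_noise` is the log of the
# noise standard deviation (an assumption about the parametrization).
cond_with_noise(log_noise) = cond(K + exp(2 * log_noise) * I)

for log_noise in (-1.0, -6.0, -12.0)
    println("log noise = ", log_noise, "  cond = ", cond_with_noise(log_noise))
end
```

The condition number grows as the noise term shrinks, because the noise variance is the only thing lifting the smallest eigenvalue of the kernel matrix away from zero.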