exoplanet-dev / celerite2

Fast & scalable Gaussian Processes in one dimension
https://celerite2.readthedocs.io
MIT License

Avoiding LinAlgErrors for closely sampled x arrays #76

Closed hposborn closed 1 year ago

hposborn commented 1 year ago

Some background: I want to model a GP for variations in flux as a function of roll-angle. Most of our observations (think HST or CHEOPS) last for a few orbits. Each orbit shows an extremely self-similar (and often not particularly sinusoidal) variation due to e.g. Earth limb, lunar glint, temperature effects, etc. Hence, including a GP on roll-angle vs flux is perfect.

My method thus far has been to sort the data into increasing roll angle and then unsort afterwards to get the variation as a function of time. However, roll angle is not a continuous observation, and for some targets we have enough data that some measurements are extremely close in roll angle to others (<1e-4 degrees), yet not particularly correlated (as they're ~days apart in time). In these cases, celerite2 simply breaks - it throws a LinAlgError. I'm not sure of the maths, but it is likely because of a large difference between two extremely close points, although it seems to break even when the difference between such points is consistent with the variance.
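The sort-then-unsort step described above can be sketched with `np.argsort`. This is a minimal illustration only: the arrays are made up, and the `np.sin` call is a hypothetical stand-in for whatever the GP would predict in roll-angle order.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data in time order: one roll angle per observation.
roll = rng.uniform(50.0, 250.0, size=8)

# Permutation that puts the roll angles in increasing order.
order = np.argsort(roll)
roll_sorted = roll[order]

# Stand-in for a model evaluated in roll-angle order (e.g. a GP prediction).
pred_sorted = np.sin(np.radians(roll_sorted))

# "Unsort": scatter the predictions back to the original time order.
pred = np.empty_like(pred_sorted)
pred[order] = pred_sorted
```

Assigning through `pred[order]` applies the inverse permutation without computing it explicitly, so `pred[i]` again corresponds to the i-th observation in time.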

Here is a working (well, non-working) example:

# Assumed imports (not shown in the original post)
import numpy as np
import pymc3 as pm
import celerite2.theano
from celerite2.theano import terms as theano_terms

# Ten fake "orbits" with varying roll angle from ~50 to ~250deg
t=np.hstack([np.linspace(50+5*np.random.random(),250+5*np.random.random(),500) for c in range(10)])

with pm.Model():
    rollangle_w0=0.1 #NB - at higher rollangle_w0, we get a different error - a Theano Assert error
    rollangle_S0=120.0
    kern=theano_terms.SHOTerm(S0=rollangle_S0,  w0=rollangle_w0, Q=1/np.sqrt(2))#, mean = phot_mean)
    gp=celerite2.theano.GaussianProcess(kern, np.sort(t), mean=0.0) #Sorting to 
    gp.compute(t=t, diag=np.tile(0.25,len(t))**2)

Relatedly, even when celerite does not break outright, it is consistently forced into extremely short-timescale variations despite a strong prior against such over-fitting...

I think this is all because of an assumption that each x measurement is effectively instantaneous - in our case we have x values (roll angles) which are actually the average across some dx range which overlaps with the neighbouring values. I'm not sure it would be possible, but is there any way to incorporate a value (or array) of dx which would stop this error/overfitting?

Alternatively, is there some way to use a Periodic kernel to enforce similar variation across all points? I've tried adapting the RotationKernel, which works well for low-frequency variation with wavelengths on the order of 100 degrees, but is typically unable to model higher-frequency variation with wavelengths on the order of tens of degrees. But maybe there is a combination of a periodic term and e.g. a spikier Matern32 kernel?

dfm commented 1 year ago

There are two issues with your example code here:

  1. When constructing the GaussianProcess object, if you pass the t coordinates, it calls compute internally. However, you're not passing diag to the constructor, meaning that it's calling compute with diag=0. If you pass diag to the constructor, everything seems to work.

  2. Relatedly, you don't need to (/shouldn't!) call compute yourself if you pass t to the constructor, but if you do, you still need to pass sorted times (the ones passed to compute in your example code are not sorted).

I'm not sure if these are the root issues that you're seeing, but fixing those would be a good place to start and see where you get!

dfm commented 1 year ago

(P.S. number 2 is why you're seeing the Assert - that happens when you pass times that are not sorted!)

hposborn commented 1 year ago

Ah yes, this explains it - I was sure that simply increasing the variance with an added jitter term in the compute function should be able to stop this problem, but it wasn't making any difference... because I didn't realise the GaussianProcess initialisation pre-computes the GP (bypassing the diag term in the compute function). And yes, you're right that I should have put np.sort(t) in the .compute function (but given the code didn't get that far...). Thanks for your help! Happy to close.
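As a rough illustration of why the jitter term matters for closely spaced points, the sketch below builds a small dense covariance matrix in NumPy and compares its smallest eigenvalue with and without the diagonal term. It does not use celerite2's solver; the helper `sho_kernel` is hypothetical and assumes the standard closed form of the SHO covariance for Q = 1/sqrt(2).

```python
import numpy as np

def sho_kernel(tau, S0=120.0, w0=0.1):
    """SHO covariance for Q = 1/sqrt(2) (assumed closed form)."""
    x = w0 * np.abs(tau) / np.sqrt(2)
    return S0 * w0 / np.sqrt(2) * np.exp(-x) * (np.cos(x) + np.sin(x))

# A coarse grid plus a copy shifted by 1e-4 deg: pairs of near-duplicate points.
grid = np.linspace(50.0, 250.0, 50)
x = np.sort(np.concatenate([grid, grid + 1e-4]))

# Dense covariance matrix with no diagonal term (the diag=0 case).
K = sho_kernel(x[:, None] - x[None, :])

jitter = 0.25**2
min_eig_noiseless = np.linalg.eigvalsh(K).min()
min_eig_jitter = np.linalg.eigvalsh(K + jitter * np.eye(x.size)).min()

print(min_eig_noiseless, min_eig_jitter)
```

Without the diagonal term the matrix is numerically close to singular, so a factorisation can fail on rounding error alone; adding the jitter lifts every eigenvalue by the same amount.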

dfm commented 1 year ago

No problem at all! The interface is definitely a bit inconsistent.