lightkurve / lightkurve

A friendly package for Kepler & TESS time series analysis in Python.
https://lightkurve.github.io/lightkurve/
MIT License
425 stars 171 forks source link

LightCurve should have a Gaussian-process-based flatten routine #115

Closed christinahedges closed 3 years ago

christinahedges commented 6 years ago

The Savitzky–Golay (savgol) filter is quick but does a poor job of flattening data. We should instead use a Gaussian process, implemented with celerite. This is much more principled and will give us uncertainties to propagate.

It can accept a mask to exclude transits and the "burn-in" data that follow downtime in data collection.

It should behave the same as regular flatten, either correcting the light curve or returning a light curve object with the flat trend.

There should be basic switches for the user to select.

It should be possible to set a lower bound on the timescale, so that the GP does not fit variability on timescales as short as a planetary transit.

christinahedges commented 6 years ago

Here are some examples of a simple GP performing better than a savgol filter.

gpdemo

This also makes a better estimate of the transit depth possible. gpdemo2

However, to get realistic uncertainties, we need to marginalize over the uncertainties in our hyperparameters. This could make run time an issue.

christinahedges commented 6 years ago

A quick demo for how to flatten K2 data with a GP.

Note: This demo is incomplete. Marginalization over the uncertainties in the GP hyperparameters still needs to be added to obtain realistic errors.

from lightkurve import KeplerLightCurveFile, KeplerLightCurve
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
pl = 'Kepler-102'

# Load the planet catalogue and keep only the rows for this host star,
# re-indexed from 0 so that row labels double as planet indices.
df = pd.read_csv('planets_corrected.csv')
df = df.loc[df['pl_hostname'] == pl].reset_index(drop=True)
summary_columns = ['pl_tranmid', 'pl_orbper', 'pl_radj', 'pl_orbsmax',
                   'st_rad', 'st_mass', 'pl_trandur']
df[summary_columns]  # bare expression: only displays output in a notebook

# Row labels (into ``df``) of every planet orbiting this host.
Planets = np.arange(len(df))
def find_mask(time, planets, table=None, duration_fit=None):
    '''Return a boolean mask that is False in and around each planet's transits.

    For every planet in ``planets``, points whose orbital phase lies within
    +/-0.75 transit durations (a window 1.5 durations wide) of phase 0 (the
    transit) or of phase 0.5 (where a secondary eclipse would fall) are set
    to False.

    Parameters
    ----------
    time : array-like
        Times to evaluate, in the same time system as ``pl_tranmid``.
    planets : iterable of int
        Row labels in ``table`` of the planets to mask.
    table : pandas.DataFrame, optional
        Catalogue with columns ``pl_orbper``, ``pl_tranmid`` and
        ``pl_trandur``. Defaults to the module-level ``df``.
    duration_fit : (slope, intercept), optional
        Linear relation log10(duration) = slope*log10(period) + intercept,
        used when a planet's ``pl_trandur`` is not finite. Defaults to the
        module-level ``line``, which is looked up only when needed.

    Returns
    -------
    numpy.ndarray of bool
        True where the point lies outside every masked window.
    '''
    if table is None:
        table = df
    time = np.asarray(time, dtype=float)
    mask = np.ones(len(time), dtype=bool)
    for idx in planets:
        per, t0, dur = table.loc[idx, ['pl_orbper', 'pl_tranmid', 'pl_trandur']]
        if not np.isfinite(dur):
            if duration_fit is None:
                # NOTE(review): ``line`` is not defined anywhere in this
                # notebook excerpt; pass ``duration_fit`` to avoid a NameError.
                duration_fit = line
            dur = 10**(duration_fit[0]*np.log10(per) + duration_fit[1])
        # Half-width of the masked window, in units of orbital phase.
        half_width = (1.5*dur/per)/2
        # No need to shift t0 by many periods first: ``%`` with a positive
        # divisor already returns a phase in [0, 1) for times before t0.
        for offset in (0., per/2):  # primary transit, then phase 0.5
            ph = (time + offset - t0)/per % 1
            mask &= (ph > half_width) & (ph < 1 - half_width)
    return mask

def find_breaks(time, break_tol=10):
    '''Return the indices that split ``time`` into contiguous segments.

    A break is placed wherever the step between consecutive samples exceeds
    ``break_tol`` times the smallest step in the series, i.e. a gap longer
    than ``break_tol`` cadences.

    Parameters
    ----------
    time : array-like
        Sample times, assumed sorted in increasing order.
    break_tol : float
        Gap threshold, in units of the minimum time step.

    Returns
    -------
    numpy.ndarray of int
        Segment boundaries, always starting with 0 and ending with
        ``len(time)``, so that ``time[b[i]:b[i+1]]`` is the i-th segment.
    '''
    time = np.asarray(time)
    n = len(time)
    if n < 2:
        # No steps to compare: the whole (possibly empty) series is one segment.
        return np.array([0, n])
    dt = np.diff(time)
    breaks = np.where(dt > break_tol*np.nanmin(dt))[0] + 1
    # End at len(time) of *this* array (the original used a global ``lc``).
    return np.concatenate(([0], breaks, [n]))
# Build the light curve: start from an empty KeplerLightCurve and append the
# normalized PDCSAP flux of quarter ``q`` for the target named by ``pl``
# (previously hard-coded, so changing ``pl`` did not change the data).
lc = KeplerLightCurve(time=[], flux=[], flux_err=[])
time = np.empty(0)  # NOTE(review): not used by any visible cell
q = 2
lcf = KeplerLightCurveFile.from_archive(pl, quarter=q)
lc = lc.append(lcf.PDCSAP_FLUX.normalize())
# Replace the time axis with the Julian dates of the file's time object.
# NOTE(review): confirm this matches the time system of ``pl_tranmid`` in the
# catalogue, otherwise the transit mask will be misplaced.
lc.time = lcf.timeobj.jd
Found 1 File(s)
INFO: Found cached file ./mastDownload/Kepler/kplr010187017_lc_Q111111111111111111/kplr010187017-2009259160929_llc.fits with expected size 466560. [astroquery.query]
fig, ax = plt.subplots(figsize=(15,3))
lc.plot(ax=ax, lw=1)

# Exclude every planet transit window, plus the first ``n_burn_in`` cadences
# at the start of the data and after each gap in data collection ('burn in').
n_burn_in = 30
mask = find_mask(lc.time, Planets)
breaks = find_breaks(lc.time, break_tol=10)
burn_in = np.zeros(len(lc.time), dtype=bool)
for start in breaks:
    # The final break equals len(lc.time), so its slice is simply empty.
    burn_in[start:start + n_burn_in] = True
mask &= ~burn_in

# Overplot the cadences that survive both cuts.
lc[mask].plot(ax=ax, lw=1, c='C3')
<matplotlib.axes._subplots.AxesSubplot at 0x1c1a0abf28>
gpexample_5_1
import celerite
from celerite import terms
from scipy.optimize import minimize

# Train only on unmasked cadences with finite values. Filter once so that
# y, err and t are guaranteed to come from the same cadences.
clean = lc[mask].remove_nans()
y = clean.flux
err = clean.flux_err
# Measure time from the first cadence of the full light curve (a new array,
# rather than subtracting in place on the light curve's own time array).
t = clean.time - lc.time[0]

# Matern-3/2 term initialised with the scatter of the full light curve and
# rho = 0.1 time units, plus a white-noise jitter term seeded from the
# median per-point uncertainty.
kernel = terms.Matern32Term(log_sigma=np.log(np.nanstd(lc.flux)), log_rho=-np.log(10.0), bounds=[(-15, 15), (-15, 15)])
kernel += terms.JitterTerm(log_sigma=np.log(np.nanmedian(err)))

# Constant mean model starting at 1 and fitted alongside the kernel.
gp = celerite.GP(kernel, mean=1, fit_mean=True)
gp.compute(t, err)  # You always need to call compute once.

def neg_log_like(params, y, gp):
    '''Objective for ``scipy.optimize.minimize``: the negative log-likelihood.

    Loads ``params`` into ``gp`` (mutating it) and returns
    ``-gp.log_likelihood(y)``, so that minimizing this maximizes the likelihood.
    '''
    gp.set_parameter_vector(params)
    log_likelihood = gp.log_likelihood(y)
    return -log_likelihood

# Maximum-likelihood fit of the GP hyperparameters (and the mean) with
# bounded L-BFGS-B, starting from the values the GP was built with.
bounds = gp.get_parameter_bounds()
initial_params = gp.get_parameter_vector()
soln = minimize(
    neg_log_like,
    initial_params,
    args=(y, gp),
    bounds=bounds,
    method="L-BFGS-B",
)

# neg_log_like leaves ``gp`` at whatever parameters were evaluated last,
# so explicitly load the optimum before using the GP.
gp.set_parameter_vector(soln.x)
gp.get_parameter_dict()
OrderedDict([('kernel:terms[0]:log_sigma', -7.372043715724431),
             ('kernel:terms[0]:log_rho', 0.5069137405276191),
             ('kernel:terms[1]:log_sigma', -9.843151308837868),
             ('mean:value', 0.9999446557591725)])
# Check that the GP prediction looks reasonable across the whole light curve.
# (``mask`` is deliberately not recomputed here: doing so would discard the
# burn-in cuts and nothing in this cell uses it.)
fig, ax = plt.subplots(figsize=(15,3))
lc.plot(ax=ax)

# Predict on every cadence, including masked ones, using the same time origin
# as the training times ``t``; shift back to ``lc.time`` for plotting.
x = lc.time - lc.time[0]
pred_mean, pred_var = gp.predict(y, x, return_var=True)
pred_std = np.sqrt(pred_var)
ax.plot(x + lc.time[0], pred_mean, color='C1', zorder=10)
# 1-sigma predictive band.
ax.fill_between(x + lc.time[0], pred_mean + pred_std, pred_mean - pred_std, color='C1', alpha=0.3, edgecolor="none")
<matplotlib.collections.PolyCollection at 0x1c1a4ce198>
gpexample_8_1
# Apply GP to flatten all the data.
# NOTE: We split up the data where there are breaks to reduce run time.
# Gaps in data collection are natural break points.

# NOTE 2: This is misleading. The TRUE error on the data needs to be found by
# marginalizing over the errors in the hyperparameters
flat = KeplerLightCurve(time=lc.time, flux=np.zeros_like(lc.flux), flux_err=np.zeros_like(lc.flux_err))
breaks = find_breaks(lc.time, break_tol=10)
for start, stop in tqdm(zip(breaks[:-1], breaks[1:]), total=len(breaks) - 1):
    segment = lc[start:stop]
    # Same time origin as the training times ``t``.
    x_seg = segment.time - lc.time[0]
    pred_mean, pred_var = gp.predict(y, x_seg, return_var=True)
    # ``gp.predict`` returns the predictive mean *including* the fitted
    # constant mean model (the check plot draws it directly over the flux),
    # so no extra ``mean:value - 1`` offset is added here; adding it would
    # shift the trend by that amount a second time.
    flat.flux[start:stop] = pred_mean
    flat.flux_err[start:stop] = np.sqrt(pred_var)
100%|██████████| 5/5 [00:01<00:00,  4.96it/s]
fig, ax = plt.subplots(figsize=(10,3))

# GP trend with its 1-sigma band...
flat.plot(ax=ax, lw=2, color='C1', label='Gaussian Process')
upper = flat.flux + flat.flux_err
lower = flat.flux - flat.flux_err
ax.fill_between(flat.time, upper, lower, color='C1', alpha=0.3, edgecolor="none")

# ...against the trend returned by ``lc.flatten``.
savgol_trend = lc.flatten(return_trend=True)[1]
savgol_trend.plot(ax=ax, lw=2, c='C2', label='Savgol Filter')

plt.savefig('GPdemo.png', dpi=150, bbox_inches='tight')
gpexample_10_0
fig, ax = plt.subplots(figsize=(10,3))

# Detrend by subtracting each trend's deviation from 1 (the flux was
# normalized). The GP version also adds the GP's predictive uncertainty to
# the per-point errors in quadrature.
lc1 = deepcopy(lc)
lc1.flux -= flat.flux - 1
lc1.flux_err = (lc1.flux_err**2 + flat.flux_err**2)**0.5
lc1.plot(ax=ax, lw=2, color='C1', label='Gaussian Process')

lc2 = deepcopy(lc)
savgol_trend = lc.flatten(return_trend=True)[1]
lc2.flux -= savgol_trend.flux - 1
lc2.plot(ax=ax, lw=1, color='C2', label='Savgol Filter')

plt.savefig('GPdemo2.png', dpi=150, bbox_inches='tight')
gpexample_11_0
barentsen commented 3 years ago

I'm closing this issue for now because it has been inactive for 2+ years, however it continues to be true that it would be very useful to add a GP- or spline-based detrending feature to Lightkurve. I would very much welcome a PR to add this!