exoplanet-dev / celerite2

Fast & scalable Gaussian Processes in one dimension
https://celerite2.readthedocs.io
MIT License

Issues sampling with PyMCv5 (gp.marginal fails when combined modelling GP+transit) #124

Open hposborn opened 3 months ago

hposborn commented 3 months ago

I am having issues using celerite2 with PyMC. In the past (with PyMC3) I always used gp.marginal(observed=y-extra_model) to sample models that included both a GP and other components (e.g. a transit model), and this worked without issue. That is no longer the case with PyMC v5, where I get TypeError: Variables that depend on other nodes cannot be used for observed data. I thought an easy alternative would be to initialise with gp.compute(), generate a predicted GP curve with gp.predict(), and then model everything the "classical" way in PyMC using pm.Normal(mu=gp_pred+extra_model, sigma=y_err, observed=y). But this gives completely different, and severely overfitted, results compared with gp.marginal() for the same model (see below).

So I would love some advice on how to model celerite2 combined with additional functions:
a) Is there any way to sample using gp.marginal() where the observed data depend on other PyMC parameters? For example, could the mean function be a vector of N_t values rather than a single value, so that the transit model goes in that way?
b) How should sampling within PyMC be done if gp.marginal() with y-extra_model is not possible? Should gp.predict() be used for this purpose at all, or am I missing a step that causes the drastic overfitting?

Some code as a MWE:

import pymc as pm
import pymc_ext as pmx
import celerite2.pymc
import arviz as az

import exoplanet as xo

import numpy as np
import matplotlib.pyplot as plt

#Initialising some sinusoidal terms to act as something for GP to remove:
sin_amps=np.exp(np.random.normal(-3,0.2,5))
sin_t0s=np.random.normal(0,15,5)
sin_pers=np.exp(np.random.normal(2,0.5,5))

#Initialising transit parameters:
i_Rs=0.8;i_Ms=0.76
i_us=np.array([0.1,0.3])
i_t0=3.197652;i_P=12.59219 #days
i_b=0.393
i_rpl=3.1309 #Rearth
i_rprs=i_rpl/(109.2*i_Rs) #planet-to-star radius ratio (109.2 Rearth per Rsun)

#Creating fake data by doing LimbDarkLightCurve
t=np.arange(0,50,1/50)
flux_err=np.tile(0.15,2500)
pure_flux = 1000*xo.LimbDarkLightCurve(i_us).get_light_curve(orbit=xo.orbits.KeplerianOrbit(r_star=i_Rs,m_star=i_Ms,period=i_P,t0=i_t0,b=i_b), r=i_rprs*i_Rs, t=t).eval()[:,0] + \
                      np.sum(sin_amps[:,None]*np.sin(2*np.pi*(t[None,:]-sin_t0s[:,None])/sin_pers[:,None]),axis=0)
flux=pure_flux+np.random.normal(0.0,np.nanmedian(flux_err),2500)

#Plotting to check:
plt.plot(t,flux,'.')
plt.plot(t,pure_flux,'--',alpha=0.7)

MWE_true_variation

The anticipated behaviour, using gp.marginal() (no transit):

with pm.Model() as model:
    logjit =pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    #Initialising GP:
    sigma=pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(lower=np.nanmedian(flux_err),
                                                              upper=np.ptp(flux),target=0.01))
    w0=pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(lower=(2*np.pi)/10,
                                                              upper=(2*np.pi)/0.2,target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
    loglik=gp.marginal("loglik", observed=flux)# - light_curve)

    gp_pred=pm.Deterministic("gp_pred",gp.predict(flux,return_var=False))
    wmarg_init_soln=pm.find_MAP()
    wmarg_trace=pm.sample(start=wmarg_init_soln)

plt.plot(t,flux,'.',alpha=0.6)
plt.plot(t,np.nanmedian(wmarg_trace.posterior['gp_pred'],axis=(0,1)),'-')
plt.savefig("MWE_fit_wmarg.png")

MWE_fit_wmarg

The arviz summary:

          mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
logjit  -3.097  0.194  -3.436   -2.759      0.004    0.003    2858.0    2106.0    1.0
mean    -0.003  0.024  -0.048    0.045      0.001    0.000    2031.0    1516.0    1.0
sigma    0.119  0.017   0.089    0.150      0.000    0.000    1796.0    1514.0    1.0
w0       1.599  0.428   0.916    2.432      0.011    0.008    1431.0    1696.0    1.0

So a GP-only model works fine.

The behaviour when including an additional non-celerite mean function (with exoplanet transit):

with pm.Model() as model:
    Rs=pm.Normal("Rs",mu=0.8,sigma=0.02)
    Ms=pm.Normal("Ms",mu=0.78,sigma=0.02)
    P=pm.Normal("P",mu=12.6,sigma=0.01)
    t0=pm.Normal("t0",mu=3.21,sigma=0.04)
    log_rprs=pm.Normal("log_rprs",mu=-4,sigma=3)
    rprs=pm.Deterministic("rprs",pm.math.exp(log_rprs))
    rpl=pm.Deterministic("rpl",rprs*Rs*109.2)
    b=xo.distributions.ImpactParameter("b",ror=rprs)
    orb = xo.orbits.KeplerianOrbit(r_star=Rs,m_star=Ms,period=P,t0=t0,b=b)
    u_stars = xo.distributions.QuadLimbDark("u_star", testval=np.array([0.3, 0.2]))
    light_curve = 1000*xo.LimbDarkLightCurve(u_stars).get_light_curve(orbit=orb, r=rprs*Rs, t=t)

    logjit =pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    #Initialising GP:
    sigma=pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(lower=np.nanmedian(flux_err),
                                                              upper=np.ptp(flux),target=0.01))
    w0=pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(lower=(2*np.pi)/10,
                                                              upper=(2*np.pi)/0.2,target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
    loglik=gp.marginal("loglik", observed=flux - light_curve)
    pm.find_MAP()

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[25], line 26
     24 gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
     25 gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
---> 26 loglik=gp.marginal("loglik", observed=flux - light_curve)
     27 pm.find_MAP()

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/celerite2/pymc/celerite2.py:96, in GaussianProcess.marginal(self, name, **kwargs)
     93 from celerite2.pymc.distribution import CeleriteNormal
     95 self._add_citations_to_pymc_model(**kwargs)
---> 96 return CeleriteNormal(
     97     name,
     98     self._mean_value,
     99     self._norm,
    100     self._t,
    101     self._c,
    102     self._U,
    103     self._W,
    104     self._d,
    105     **kwargs
    106 )

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/distributions/distribution.py:413, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    409         kwargs["shape"] = tuple(observed.shape)
    411 rv_out = cls.dist(*args, **kwargs)
--> 413 rv_out = model.register_rv(
    414     rv_out,
    415     name,
    416     observed,
    417     total_size,
    418     dims=dims,
    419     transform=transform,
    420     initval=initval,
    421 )
    423 # add in pretty-printing support
    424 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/model/core.py:1265, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1252 else:
   1253     if (
   1254         isinstance(observed, Variable)
   1255         and not isinstance(observed, GenTensorVariable)
   (...)
   1263         and not is_minibatch(observed)
   1264     ):
-> 1265         raise TypeError(
   1266             "Variables that depend on other nodes cannot be used for observed data."
   1267             f"The data variable was: {observed}"
   1268         )
   1270     # `rv_var` is potentially changed by `make_obs_var`,
   1271     # for example into a new graph for imputation of missing data.
   1272     rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)

TypeError: Variables that depend on other nodes cannot be used for observed data.The data variable was: Sub.0

I have verified that the same error occurs across different computers (both my M2 Mac and linux server).

The behaviour when sampling with the output of gp.predict():

with pm.Model() as model:
    logjit =pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    #Initialising GP:
    sigma=pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(lower=np.nanmedian(flux_err),
                                                              upper=np.ptp(flux),target=0.01))
    w0=pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(lower=(2*np.pi)/10,
                                                              upper=(2*np.pi)/0.2,target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
    gp_pred=pm.Deterministic("gp_pred",gp.predict(flux, return_var=False))
    loglik=pm.Normal("loglik", mu=gp_pred, sigma=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), observed=flux)# - light_curve)

    nomarg_init_soln=pm.find_MAP()
    nomarg_trace=pm.sample(start=nomarg_init_soln)

plt.plot(t,flux,'.',alpha=0.6)
plt.plot(t,np.nanmedian(nomarg_trace.posterior['gp_pred'],axis=(0,1)),'-')
plt.savefig("MWE_fit_nomarg.png")

MWE_fit_nomarg

The arviz summary:

           mean       sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
logjit   -5.297    0.391  -6.014   -4.597      0.009    0.006    2263.0    1961.0   1.00
mean     -0.000    0.084  -0.150    0.164      0.002    0.001    2608.0    2583.0   1.00
sigma     1.559    0.874   0.652    2.833      0.030    0.021    1389.0     833.0   1.00
w0      163.862  205.731  21.684  393.793      7.451    5.271     992.0     849.0   1.01

This is clearly extremely over-fitted for some reason...
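
A plausible reading of the overfitting, offered as a sketch rather than a diagnosis: gp.predict(flux) returns the GP mean conditioned on the very same flux values, so using it as mu in pm.Normal counts the data twice. For a dense GP with kernel matrix K and diagonal noise N, the conditional mean is K (K + N)^-1 y, so the residual y - mu equals N (K + N)^-1 y, which shrinks toward zero as the kernel amplitude grows. A Normal likelihood on that residual would then reward a large sigma and a small jitter, which matches the direction of the summary above. The toy below is plain Python, not celerite2; the squared-exponential kernel, the three data points, and the helper names are all made up for illustration.

```python
import math

def solve(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        s = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - s) / m[r][r]
    return x

def predictive_residual_rms(amp, t, y, noise_var):
    """RMS of y - K (K + N)^-1 y for a toy kernel with amplitude `amp`."""
    n = len(t)
    # Hypothetical squared-exponential kernel standing in for the SHO term.
    k = [[amp**2 * math.exp(-0.5 * (t[i] - t[j]) ** 2) for j in range(n)]
         for i in range(n)]
    k_plus_n = [[k[i][j] + (noise_var if i == j else 0.0) for j in range(n)]
                for i in range(n)]
    alpha = solve(k_plus_n, y)  # alpha = (K + N)^-1 y
    # Since y = (K + N) alpha, the residual y - K alpha is simply N alpha.
    resid = [noise_var * a for a in alpha]
    return math.sqrt(sum(r * r for r in resid) / n)

t = [0.0, 1.0, 2.0]      # made-up time stamps
y = [0.3, -0.2, 0.1]     # made-up flux values
rms_by_amp = [predictive_residual_rms(amp, t, y, noise_var=0.15**2)
              for amp in (0.1, 1.0, 10.0)]
print(rms_by_amp)  # residual RMS for increasing kernel amplitude
```

If this reading is right, the residual can be driven as small as desired just by inflating the kernel amplitude, so the pm.Normal likelihood has nothing to penalise an overly flexible GP. The marginal likelihood avoids this because it never conditions the mean on the data being scored.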

hposborn commented 3 months ago

OK, it looks like using pm.Potential("loglik", gp.log_likelihood(y-extra_model)) is the way to go:

with pm.Model() as model:
    Rs=pm.Normal("Rs",mu=0.8,sigma=0.02)
    Ms=pm.Normal("Ms",mu=0.78,sigma=0.02)
    P=pm.Normal("P",mu=12.6,sigma=0.01)
    t0=pm.Normal("t0",mu=3.21,sigma=0.04)
    log_rprs=pm.Normal("log_rprs",mu=-4,sigma=3,initval=-2)
    rprs=pm.Deterministic("rprs",pm.math.exp(log_rprs))
    rpl=pm.Deterministic("rpl",rprs*Rs*109.2)
    b=xo.distributions.ImpactParameter("b",ror=rprs,initval=0.4)
    orb = xo.orbits.KeplerianOrbit(r_star=Rs,m_star=Ms,period=P,t0=t0,b=b)
    u_stars = xo.distributions.QuadLimbDark("u_star", testval=np.array([0.3, 0.2]))
    lightcurve=pm.Deterministic('lightcurve',1000*xo.LimbDarkLightCurve(u_stars).get_light_curve(orbit=orb, r=rprs*Rs, t=t))
    logjit =pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    #Initialising GP:
    sigma=pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(lower=np.nanmedian(flux_err),
                                                              upper=np.ptp(flux),target=0.01))
    w0=pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(lower=(2*np.pi)/10,
                                                              upper=(2*np.pi)/0.2,target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux, t=t, diag=flux_err** 2 + pm.math.exp(logjit)**2)
    loglik=pm.Potential("loglik", gp.log_likelihood(flux-pm.math.sum(lightcurve,axis=1)))
    gp_pred=pm.Deterministic("gp_pred", gp.predict(flux-pm.math.sum(lightcurve,axis=1), return_var=False))

    #nomarg_trans_init_soln=pm.find_MAP()
    nomarg_trans_trace=pm.sample()

plt.plot(t,flux,'.',alpha=0.6)
plt.plot(t,np.nanmedian(nomarg_trans_trace.posterior['gp_pred'].values+nomarg_trans_trace.posterior['lightcurve'].values[:,:,:,0],axis=(0,1)),'-')
plt.savefig("MWE_fit_nomarg_trans.png")

MWE_fit_nomarg_trans

So that seems to fix it! Though I am a bit apprehensive about this, since it turns the likelihood into something of a black box, which sometimes doesn't work with ArviZ functions such as WAIC. Any advice on calling pm.Normal or gp.marginal directly would therefore still be useful.
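
For reference, the residual form used above should be the same density as a Gaussian whose mean is a full vector (question (a) in the first comment): log N(y | m, K) equals log N(y - m | 0, K). The sketch below checks this numerically in plain Python for a two-point case. The function name, the covariance, and the "transit model" values are hypothetical and have nothing to do with celerite2 internals.

```python
import math

def gauss_logpdf_2d(y, mean, cov):
    """Log-density of a 2D Gaussian with mean `mean` and covariance `cov`."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    r0, r1 = y[0] - mean[0], y[1] - mean[1]
    # r^T cov^-1 r using the closed-form inverse of a 2x2 matrix
    quad = (d * r0 * r0 - (b + c) * r0 * r1 + a * r1 * r1) / det
    return -0.5 * (quad + math.log(det) + 2 * math.log(2 * math.pi))

y = [1.2, 0.7]               # made-up observations
extra_model = [0.9, 0.1]     # made-up stand-in for the transit model
cov = [[0.5, 0.2], [0.2, 0.4]]

# Option 1: vector-valued mean, data left untouched.
with_mean = gauss_logpdf_2d(y, extra_model, cov)

# Option 2: zero mean, model subtracted from the data (the Potential approach).
resid = [y[0] - extra_model[0], y[1] - extra_model[1]]
on_resid = gauss_logpdf_2d(resid, [0.0, 0.0], cov)

print(with_mean, on_resid)
```

So the Potential should carry the same information as a marginal likelihood with a time-dependent mean; the difference is bookkeeping, since a Potential only adds a term to the model log-probability and, as far as I know, does not populate the pointwise log-likelihood that WAIC needs. The celerite2 documentation describes the mean argument of GaussianProcess as either a scalar or a callable evaluated on the time vector, so a callable returning mean_flux plus the transit model might let gp.marginal("loglik", observed=flux) work with plain data. That route is untested here.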