exoplanet-dev / celerite2

Fast & scalable Gaussian Processes in one dimension
https://celerite2.readthedocs.io
MIT License

Term.get_coefficients fails for jax implementation when built from source #75

Open bmorris3 opened 1 year ago

bmorris3 commented 1 year ago

Hi @dfm,

Today I built celerite2 from source, following the recommendations in the install docs. I'm trying to do something simple, like this:

from celerite2.jax import terms

sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()

but I'm getting the following error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [1], line 4
      1 from celerite2.jax import terms
      3 sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
----> 4 sho.get_coefficients()

File ~/git/celerite2/python/celerite2/jax/terms.py:36, in Term.get_coefficients(self)
     35 def get_coefficients(self):
---> 36     raise NotImplementedError("subclasses must implement this method")

NotImplementedError: subclasses must implement this method

At first I thought this might be caused by the multiple SHOTerm implementations in celerite2/jax/terms.py, for example here:

https://github.com/exoplanet-dev/celerite2/blob/e75dd45ca4f033b22d3ca19d0545a097c7441495/python/celerite2/jax/terms.py#L473-L478

and here:

https://github.com/exoplanet-dev/celerite2/blob/e75dd45ca4f033b22d3ca19d0545a097c7441495/python/celerite2/jax/terms.py#L481

but commenting out the first one doesn't solve the problem.

Any ideas? Thanks!

bmorris3 commented 1 year ago

Writing an issue is always the most clarifying exercise.

I was hitting this error because I was using TermConvolution, which relies on these lines:

https://github.com/exoplanet-dev/celerite2/blob/e75dd45ca4f033b22d3ca19d0545a097c7441495/python/celerite2/jax/terms.py#L301-L310

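but the JAX implementation of SHOTerm defines the overdamped and underdamped coefficients in separate methods (get_overdamped_coefficients and get_underdamped_coefficients) instead of implementing get_coefficients itself. Should TermConvolution be modified accordingly?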
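
To make the distinction concrete, here is a minimal standalone sketch of the two coefficient sets, assuming the standard celerite SHO-term formulas. The function names are hypothetical and the code does not use celerite2 itself. For Q < 0.5 the kernel is a sum of two real exponential terms (a, c); for Q > 0.5 it is a single damped-oscillation term (a, b, c, d). Because the two regimes produce different numbers and kinds of terms, anything that consumes a single get_coefficients() tuple, such as TermConvolution, needs them combined into one result.

```python
import jax.numpy as jnp

# Hypothetical standalone helpers (not celerite2's API), assuming the standard
# celerite SHO-term coefficients. Q == 0.5 exactly is not handled (f == 0).


def sho_overdamped(S0, w0, Q):
    """Q < 0.5: two real terms, returned as arrays (a, c)."""
    f = jnp.sqrt(1.0 - 4.0 * Q**2)
    a = 0.5 * S0 * w0 * Q * jnp.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f])
    c = 0.5 * w0 / Q * jnp.stack([1.0 - f, 1.0 + f])
    return a, c


def sho_underdamped(S0, w0, Q):
    """Q > 0.5: one complex term, returned as length-1 arrays (a, b, c, d)."""
    f = jnp.sqrt(4.0 * Q**2 - 1.0)
    a = S0 * w0 * Q
    c = 0.5 * w0 / Q
    return (jnp.array([a]), jnp.array([a / f]),
            jnp.array([c]), jnp.array([c * f]))


def kernel_real(a, c, tau):
    # k(tau) = sum_j a_j exp(-c_j tau)
    return jnp.sum(a * jnp.exp(-c * tau))


def kernel_complex(a, b, c, d, tau):
    # k(tau) = sum_j exp(-c_j tau) [a_j cos(d_j tau) + b_j sin(d_j tau)]
    return jnp.sum(jnp.exp(-c * tau) * (a * jnp.cos(d * tau) + b * jnp.sin(d * tau)))


# Both regimes give the same zero-lag variance, S0 * w0 * Q
print(kernel_real(*sho_overdamped(1.0, 3000.0, 0.3), 0.0))      # approximately 900
print(kernel_complex(*sho_underdamped(1.0, 3000.0, 0.6), 0.0))  # approximately 1800
```

Each helper returns NaN when called outside its own regime, because the argument of the square root goes negative.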

dfm commented 1 year ago

Great question! This is an issue introduced by this PR: https://github.com/exoplanet-dev/celerite2/pull/68

I should spend some time thinking about how to fix this properly, but one option for the short term (at the cost of what should be a minor performance hit) would be something like:

import jax
import jax.numpy as jnp
from celerite2.jax import terms

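def custom_sho_get_coeffs(self):
    # Compute both coefficient sets; Q < 0.5 is the overdamped regime
    ar, cr = self.get_overdamped_coefficients()
    ac, bc, cc, dc = self.get_underdamped_coefficients()
    cond = jnp.less(self.Q, 0.5)
    # Keep the real (overdamped) coefficients when cond is True, else zero them
    selectr = lambda x: jax.lax.cond(cond, lambda y: y, jnp.zeros_like, operand=x)
    # Keep the complex (underdamped) coefficients when cond is False, else zero them
    selectc = lambda x: jax.lax.cond(cond, jnp.zeros_like, lambda y: y, operand=x)
    return selectr(ar), selectr(cr), selectc(ac), selectc(bc), selectc(cc), selectc(dc)

# Patch the method onto the JAX SHOTerm class
terms.SHOTerm.get_coefficients = custom_sho_get_coeffs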
sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()
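
The selection step in this workaround can also be sketched with jnp.where instead of jax.lax.cond. The helper below is hypothetical (select_sho_coefficients is not part of celerite2): it takes the two coefficient tuples and zeroes out the unused one. jnp.where is elementwise, so it also accepts an array-valued Q, whereas lax.cond needs a scalar predicate. Either way both coefficient sets are computed, and the result always carries both a real and a complex part, one with zero amplitude, which is presumably where the extra cost comes from.

```python
import jax
import jax.numpy as jnp


def select_sho_coefficients(Q, overdamped, underdamped):
    """Hypothetical helper (not part of celerite2): keep one coefficient set.

    overdamped:  (ar, cr), the real-term coefficients
    underdamped: (ac, bc, cc, dc), the complex-term coefficients
    Returns the 6-tuple (ar, cr, ac, bc, cc, dc) with the unused set zeroed.
    """
    # Same threshold as the workaround: Q < 0.5 means overdamped
    is_over = jnp.less(Q, 0.5)
    # jnp.where broadcasts the predicate against each array elementwise
    ar, cr = (jnp.where(is_over, x, jnp.zeros_like(x)) for x in overdamped)
    ac, bc, cc, dc = (jnp.where(is_over, jnp.zeros_like(x), x) for x in underdamped)
    return ar, cr, ac, bc, cc, dc


# The helper is traceable, so it can be jit-compiled directly
select_sho_coefficients_jit = jax.jit(select_sho_coefficients)

# Usage with dummy placeholder arrays (not real SHO coefficients)
over = (jnp.ones(2), jnp.ones(2))
under = tuple(jnp.ones(1) for _ in range(4))
coeffs = select_sho_coefficients_jit(0.6, over, under)  # real-term entries zeroed
```

One caveat: if the unused branch contains NaNs (for example, overdamped formulas evaluated at Q > 0.5), the forward value is still correct, but gradients taken through jnp.where can still come out as NaN, so the coefficient functions themselves need to guard against that.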

bmorris3 commented 1 year ago

I can confirm the fix above works for me for now. Thanks Dan!