bmorris3 opened this issue 1 year ago
Writing an issue is always the most clarifying exercise.
I was hitting this error because I was using TermConvolution, which uses this line:
but the JAX implementation of SHOTerm defines the underdamped and overdamped versions of the coefficients separately. Should TermConvolution be modified accordingly?
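For context, here is why the two regimes end up with different coefficient arrays. This is a worked sketch that assumes the standard celerite form of the kernel and of the SHO term; it is not copied from the celerite2 source. A celerite kernel is a sum of real terms with coefficients (a_j, c_j) and complex terms with coefficients (a_j, b_j, c_j, d_j):

```latex
k(\tau) = \sum_j a_j e^{-c_j \tau}
        + \sum_j e^{-c_j \tau}\left[a_j \cos(d_j \tau) + b_j \sin(d_j \tau)\right]

% Underdamped, Q > 1/2, with f = \sqrt{4Q^2 - 1}: one complex term
a = S_0 \omega_0 Q, \quad b = \frac{S_0 \omega_0 Q}{f}, \quad
c = \frac{\omega_0}{2Q}, \quad d = \frac{\omega_0 f}{2Q}

% Overdamped, Q < 1/2, with f = \sqrt{1 - 4Q^2}: two real terms
a_\pm = \frac{S_0 \omega_0 Q}{2}\left(1 \pm \frac{1}{f}\right), \quad
c_\pm = \frac{\omega_0}{2Q}\left(1 \mp f\right)
```

So an overdamped SHOTerm contributes two real terms and no complex terms, while an underdamped one contributes one complex term and no real terms. At Q = 1/2 both versions of f vanish, so that boundary needs separate care. Code such as TermConvolution that consumes a single coefficient tuple therefore sees arrays whose shapes depend on Q.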
Great question! This is an issue introduced by this PR: https://github.com/exoplanet-dev/celerite2/pull/68
I should probably spend some time thinking about how to fix this properly, but one option for the short term (it'll probably come with a minor performance hit) would be something like:
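```python
import jax
import jax.numpy as jnp
from celerite2.jax import terms


def custom_sho_get_coeffs(self):
    # Compute both parameterizations; the inactive one is zeroed out below
    ar, cr = self.get_overdamped_coefficients()
    ac, bc, cc, dc = self.get_underdamped_coefficients()
    # Q < 1/2 means the oscillator is overdamped
    cond = jnp.less(self.Q, 0.5)
    # Keep the real (overdamped) coefficients when cond is True, else zero them
    selectr = lambda x: jax.lax.cond(cond, lambda y: y, jnp.zeros_like, operand=x)
    # Keep the complex (underdamped) coefficients when cond is False, else zero them
    selectc = lambda x: jax.lax.cond(cond, jnp.zeros_like, lambda y: y, operand=x)
    return selectr(ar), selectr(cr), selectc(ac), selectc(bc), selectc(cc), selectc(dc)


# Monkeypatch SHOTerm so it always returns both term types with fixed shapes
terms.SHOTerm.get_coefficients = custom_sho_get_coeffs

sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()
```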
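If monkeypatching is undesirable, the same fixed-shape idea can be sketched as a standalone function. sho_coefficients_fixed_shape and kernel below are hypothetical helpers, not celerite2 APIs, and the coefficient formulas assume the standard celerite SHO parameterization. The sketch uses NumPy so it runs without JAX; because it only calls where, sqrt, maximum, array and zeros_like, swapping numpy for jax.numpy is assumed (not verified) to make it traceable.

```python
import numpy as np


# Hypothetical helper (not a celerite2 API); formulas assume the standard
# celerite parameterization of the SHO term.
def sho_coefficients_fixed_shape(S0, w0, Q, eps=1e-5):
    """Return (ar, cr, ac, bc, cc, dc) whose shapes do not depend on Q.

    Both regimes are always computed and the inactive one is zeroed with
    np.where, giving two real terms plus one complex term for any Q.
    """
    overdamped = Q < 0.5

    # Overdamped (Q < 1/2): two real terms. The sqrt argument is clamped so
    # the unused branch stays finite instead of producing NaN.
    f_r = np.sqrt(np.maximum(1.0 - 4.0 * Q**2, eps))
    ar = 0.5 * S0 * w0 * Q * np.array([1.0 + 1.0 / f_r, 1.0 - 1.0 / f_r])
    cr = 0.5 * w0 / Q * np.array([1.0 - f_r, 1.0 + f_r])

    # Underdamped (Q > 1/2): one complex term
    f_c = np.sqrt(np.maximum(4.0 * Q**2 - 1.0, eps))
    a = S0 * w0 * Q
    c = 0.5 * w0 / Q
    ac, bc, cc, dc = (np.array([v]) for v in (a, a / f_c, c, c * f_c))

    keep_r = lambda x: np.where(overdamped, x, np.zeros_like(x))
    keep_c = lambda x: np.where(overdamped, np.zeros_like(x), x)
    return (keep_r(ar), keep_r(cr),
            keep_c(ac), keep_c(bc), keep_c(cc), keep_c(dc))


def kernel(tau, ar, cr, ac, bc, cc, dc):
    """Evaluate the celerite kernel sum at lag tau from the coefficients."""
    real = np.sum(ar * np.exp(-cr * tau))
    cplx = np.sum(np.exp(-cc * tau) * (ac * np.cos(dc * tau) + bc * np.sin(dc * tau)))
    return real + cplx


coeffs = sho_coefficients_fixed_shape(1.0, 3000.0, 0.6)
print([x.shape for x in coeffs])  # [(2,), (2,), (1,), (1,), (1,), (1,)]
```

A design note: np.where (like jnp.where) evaluates both branches, which is why both square-root arguments are clamped with eps. Under JAX, a NaN in the unselected branch does not change the forward value but can still turn gradients into NaN, so the clamp matters there as well.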
I can confirm the fix above works for me for now. Thanks Dan!
Hi @dfm,
Today I built celerite2 from source, following the recommendations in the install docs. I'm trying to do something simple, like this:
but I'm getting the following error:
At first I thought this could be an accident of the multiple SHOTerm implementations, for example here https://github.com/exoplanet-dev/celerite2/blob/e75dd45ca4f033b22d3ca19d0545a097c7441495/python/celerite2/jax/terms.py#L473-L478 and here https://github.com/exoplanet-dev/celerite2/blob/e75dd45ca4f033b22d3ca19d0545a097c7441495/python/celerite2/jax/terms.py#L481, but commenting out the first one doesn't solve the problem.
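One general Python point that may be relevant: when a module binds the same top-level name twice, importers only ever see the later binding, so commenting out the earlier definition leaves the module's behavior unchanged. A minimal illustration, using a hypothetical two-definition module body rather than the actual celerite2 terms.py:

```python
import types

# Hypothetical module body that binds SHOTerm twice (not the real terms.py)
SOURCE_BOTH = '''
class SHOTerm:
    origin = "first definition"

class SHOTerm:
    origin = "second definition"
'''

# The same body with the first definition commented out
SOURCE_SECOND_ONLY = '''
# class SHOTerm:
#     origin = "first definition"

class SHOTerm:
    origin = "second definition"
'''


def load(name, source):
    """Execute source as a fresh module object and return it."""
    module = types.ModuleType(name)
    exec(compile(source, f"<{name}>", "exec"), module.__dict__)
    return module


both = load("terms_both", SOURCE_BOTH)
second_only = load("terms_second_only", SOURCE_SECOND_ONLY)
print(both.SHOTerm.origin)         # second definition
print(second_only.SHOTerm.origin)  # second definition
```

So if the error persists after removing the first definition, the cause is more likely in the definition that remains, or in the code that calls it, than in the duplication itself.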
Any ideas? Thanks!