exoplanet-dev / jaxoplanet

Astronomical time series analysis with JAX
https://jax.exoplanet.codes
MIT License

How to jit the computation of jaxoplanet.experimental.starry light curves? #178

Open shashankdholakia opened 5 months ago

shashankdholakia commented 5 months ago

I'm trying to jit the computation of starry light curves so I can compute light curves for different System parameters efficiently.

Currently, if I run the starry.ipynb notebook but modify the last cell to read:

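import matplotlib.pyplot as plt
import numpy as np
from equinox import filter_jit
from jaxoplanet.experimental.starry.light_curves import light_curve

# `system` is the System built in the earlier cells of starry.ipynb
time = np.linspace(-2.0, 2.0, 1000)
flux = filter_jit(light_curve)(system)(time)

plt.figure(figsize=(12, 4))
_ = plt.plot(flux, c="k")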

I get an error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1] wrapped in a DynamicJaxprTracer to escape the scope of the transformation. JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.

What's the best way to do this?
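For context, the JAX documentation for `jax.errors.UnexpectedTracerError` illustrates this error with a jitted function that appends an intermediate value to a global list. The sketch below follows that pattern. It only shows the general mechanism behind the message; it does not identify where jaxoplanet or `filter_jit` ends up holding such a reference.

```python
import jax

leaked = []  # global state that outlives the jit trace


@jax.jit
def side_effecting(x):
    y = x + 1
    leaked.append(y)  # side effect: stores a tracer in global state
    return y


side_effecting(1.0)

try:
    leaked[0] + 1  # using the escaped tracer outside the transformation
except jax.errors.UnexpectedTracerError:
    caught = True
else:
    caught = False
```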

dfm commented 5 months ago

(Note: I've never used filter_jit so I don't know if it has different semantics.)

Typically I would include this within a helper function, e.g.:

import jax
from jaxoplanet.experimental.starry.light_curves import light_curve

@jax.jit
def compute_light_curve(params, time):
    system = ...  # build system using params
    return light_curve(system)(time)

which should work fine. This is probably preferable to jitting with a system as the input, because constructing a system has some Python overhead and it's better to compile that away. But if that's really what you want, something like the following should work:

flux_func = jax.jit(lambda system, time: light_curve(system)(time))
flux = flux_func(system, time)
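To make both patterns concrete without relying on jaxoplanet's constructors (their exact signatures aren't shown in this thread), here is a runnable sketch in which a hypothetical `ToySystem` NamedTuple and a hypothetical `toy_light_curve` function stand in for `System` and `light_curve`. It assumes the only requirement for the second pattern is that the system is a JAX pytree of arrays, which a NamedTuple is.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in for a System; a NamedTuple is a JAX pytree."""
    period: jax.Array
    depth: jax.Array
    duration: jax.Array


def toy_light_curve(system):
    """Hypothetical stand-in for light_curve: returns a function of time."""
    def flux(time):
        # Orbital phase in [-0.5, 0.5), with the transit centred at phase 0
        phase = jnp.mod(time / system.period + 0.5, 1.0) - 0.5
        in_transit = jnp.abs(phase * system.period) < 0.5 * system.duration
        return 1.0 - system.depth * in_transit  # box-shaped dip
    return flux


# Pattern 1: build the system inside the jitted function from raw parameters,
# so the Python-side construction runs only once, at trace time.
@jax.jit
def compute_light_curve(params, time):
    system = ToySystem(params["period"], params["depth"], params["duration"])
    return toy_light_curve(system)(time)


# Pattern 2: pass the system itself as a pytree argument.
flux_func = jax.jit(lambda system, time: toy_light_curve(system)(time))

time = jnp.linspace(-2.0, 2.0, 1000)
params = {"period": jnp.array(1.0), "depth": jnp.array(0.01), "duration": jnp.array(0.1)}
flux1 = compute_light_curve(params, time)
flux2 = flux_func(ToySystem(**params), time)

# Because params is a pytree of arrays, a batch of parameter sets can be
# evaluated in a single call by vmapping over the leading axis of each leaf.
batched = {
    "period": jnp.array([1.0, 0.5]),
    "depth": jnp.array([0.01, 0.02]),
    "duration": jnp.array([0.1, 0.1]),
}
fluxes = jax.vmap(compute_light_curve, in_axes=(0, None))(batched, time)  # (2, 1000)
```

Calling `compute_light_curve` again with new parameter values of the same shapes and dtypes reuses the compiled function instead of retracing, which is what makes repeated evaluation over different parameters cheap.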