exoplanet-dev / jaxoplanet

Astronomical time series analysis with JAX
https://jax.exoplanet.codes
MIT License
41 stars 12 forks source link

LimbDarkLightCurve().light_curve() method fails for multiplanetary systems #112

Closed soichiro-hattori closed 9 months ago

soichiro-hattori commented 9 months ago

The current version of the `LimbDarkLightCurve().light_curve()` method fails for multiplanetary systems — for example, when using the Keplerian orbit and passing more than one body to the `bodies` argument of `keplerian.System`.

The following code snippet reproduces the error:

# Minimal reproducer for issue #112: LimbDarkLightCurve().light_curve()
# raising a TypeError for a keplerian.System built from a list of bodies.
import numpy as np
import matplotlib.pyplot as plt  # not used below; kept from the original snippet
import jax

# Enable 64-bit floats before any other module can create JAX arrays:
# JAX only applies this flag reliably when it is set at start-up.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp  # not used below
import jpu.numpy as jnpu  # not used below

from jaxoplanet.orbits import keplerian
from jaxoplanet.light_curves import LimbDarkLightCurve
from jaxoplanet.types import Quantity  # not used below
from jaxoplanet.units import unit_registry as ureg  # not used below

star = keplerian.Central(mass=5, radius=2)
planet_a = keplerian.Body(radius=0.05, period=1)
planet_b = keplerian.Body(radius=0.2, period=2)
# Pass both bodies so the snippet exercises the multiplanetary case the
# issue title describes. NOTE: the traceback recorded in this issue was
# captured with bodies=[planet_a] only.
system = keplerian.System(central=star, bodies=[planet_a, planet_b])

# 500 evenly spaced sample times from -5 to 5.
t = np.linspace(-5, 5, 500)
lc = LimbDarkLightCurve().light_curve(system, t)

Running it fails with the following error:

{
    "name": "TypeError",
    "message": "Unexpected input type for array: <class 'pint.Quantity'>",
    "stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 20
     17 system = keplerian.System(central=star, bodies=[planet_a])
     19 t = np.linspace(-5, 5, 500)
---> 20 lc = LimbDarkLightCurve().light_curve(system, t)

File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/equinox/_module.py:861, in BoundMethod.__call__(self, *args, **kwargs)
    860 def __call__(self, *args, **kwargs):
--> 861     return self.__func__(self.__self__, *args, **kwargs)

File ~/jaxoplanet/src/jaxoplanet/units/decorator.py:173, in QuantityInput.__call__.<locals>.wrapped(*args, **kwargs)
    165     if unit is not None:
    166         bound_args.arguments[name] = jax.tree_util.tree_map(
    167             partial(_apply_units, name=name, strict=self.strict),
    168             value,
    169             unit,
    170             is_leaf=_is_quantity,
    171         )
--> 173 return func(*bound_args.args, **bound_args.kwargs)

File ~/jaxoplanet/src/jaxoplanet/light_curves.py:109, in LimbDarkLightCurve.light_curve(self, orbit, t, texp, oversample, texp_order, limbdark_order)
    107 else:
    108     b /= r_star[..., None]
--> 109     lc = jnp.vectorize(lc_func, signature=\"(k),()->(k)\")(b, r)
    110 lc = jnp.where(z > 0, lc, 0)
    112 # Integrate over exposure time

File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py:280, in vectorize.<locals>.wrapped(*args)
    277   excluded_func, args = _apply_excluded(excluded_func, none_args, args)
    278   input_core_dims = [dim for i, dim in enumerate(input_core_dims) if i not in none_args]
--> 280 args = tuple(map(jnp.asarray, args))
    282 broadcast_shape, dim_sizes = _parse_input_dimensions(
    283     args, input_core_dims, error_context)
    285 checked_func = _check_output_dims(
    286     excluded_func, dim_sizes, output_core_dims, error_context)

File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2206, in asarray(a, dtype, order)
   2204 if dtype is not None:
   2205   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 2206 return array(a, dtype=dtype, copy=False, order=order)

File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2169, in array(object, dtype, copy, order, ndmin)
   2166   else:
   2167     return array(np.asarray(view), dtype, copy, ndmin=ndmin)
-> 2169   raise TypeError(f\"Unexpected input type for array: {type(object)}\")
   2171 out_array: Array = lax_internal._convert_element_type(
   2172     out, dtype, weak_type=weak_type)
   2173 if ndmin > ndim(out_array):

TypeError: Unexpected input type for array: <class 'pint.Quantity'>"
}