The current version of the `LimbDarkLightCurve().light_curve()` method fails for multi-planet systems, for example when using the `keplerian` orbit and passing more than one body to the `bodies` argument of `keplerian.System`.
The following code snippet reproduces the error:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jpu.numpy as jnpu
from jaxoplanet.orbits import keplerian
from jaxoplanet.light_curves import LimbDarkLightCurve
from jaxoplanet.types import Quantity
from jaxoplanet.units import unit_registry as ureg
# Enable 64-bit (double-precision) floats in JAX.
jax.config.update("jax_enable_x64", True)

# A central star plus two planets on Keplerian orbits.
star = keplerian.Central(mass=5, radius=2)
planet_a = keplerian.Body(radius=0.05, period=1)
planet_b = keplerian.Body(radius=0.2, period=2)

# Pass BOTH bodies so the snippet exercises the multi-planet case described above
# (previously planet_b was defined but never used).
# NOTE: the traceback below was captured with ``bodies=[planet_a]``, i.e. a
# single-body System already raises the same TypeError.
system = keplerian.System(central=star, bodies=[planet_a, planet_b])

# Times at which the light curve is evaluated.
t = np.linspace(-5, 5, 500)

# Expected to fail with:
#   TypeError: Unexpected input type for array: <class 'pint.Quantity'>
# raised from the jnp.vectorize call inside LimbDarkLightCurve.light_curve.
lc = LimbDarkLightCurve().light_curve(system, t)
It fails with the following error:
{
"name": "TypeError",
"message": "Unexpected input type for array: <class 'pint.Quantity'>",
"stack": "---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 20
17 system = keplerian.System(central=star, bodies=[planet_a])
19 t = np.linspace(-5, 5, 500)
---> 20 lc = LimbDarkLightCurve().light_curve(system, t)
File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/equinox/_module.py:861, in BoundMethod.__call__(self, *args, **kwargs)
860 def __call__(self, *args, **kwargs):
--> 861 return self.__func__(self.__self__, *args, **kwargs)
File ~/jaxoplanet/src/jaxoplanet/units/decorator.py:173, in QuantityInput.__call__.<locals>.wrapped(*args, **kwargs)
165 if unit is not None:
166 bound_args.arguments[name] = jax.tree_util.tree_map(
167 partial(_apply_units, name=name, strict=self.strict),
168 value,
169 unit,
170 is_leaf=_is_quantity,
171 )
--> 173 return func(*bound_args.args, **bound_args.kwargs)
File ~/jaxoplanet/src/jaxoplanet/light_curves.py:109, in LimbDarkLightCurve.light_curve(self, orbit, t, texp, oversample, texp_order, limbdark_order)
107 else:
108 b /= r_star[..., None]
--> 109 lc = jnp.vectorize(lc_func, signature=\"(k),()->(k)\")(b, r)
110 lc = jnp.where(z > 0, lc, 0)
112 # Integrate over exposure time
File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py:280, in vectorize.<locals>.wrapped(*args)
277 excluded_func, args = _apply_excluded(excluded_func, none_args, args)
278 input_core_dims = [dim for i, dim in enumerate(input_core_dims) if i not in none_args]
--> 280 args = tuple(map(jnp.asarray, args))
282 broadcast_shape, dim_sizes = _parse_input_dimensions(
283 args, input_core_dims, error_context)
285 checked_func = _check_output_dims(
286 excluded_func, dim_sizes, output_core_dims, error_context)
File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2206, in asarray(a, dtype, order)
2204 if dtype is not None:
2205 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
-> 2206 return array(a, dtype=dtype, copy=False, order=order)
File ~/miniforge3/envs/jaxoplanet/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2169, in array(object, dtype, copy, order, ndmin)
2166 else:
2167 return array(np.asarray(view), dtype, copy, ndmin=ndmin)
-> 2169 raise TypeError(f\"Unexpected input type for array: {type(object)}\")
2171 out_array: Array = lax_internal._convert_element_type(
2172 out, dtype, weak_type=weak_type)
2173 if ndmin > ndim(out_array):
TypeError: Unexpected input type for array: <class 'pint.Quantity'>"
}
The current version of the `LimbDarkLightCurve().light_curve()` method fails for multiplanetary systems, for example when using the `keplerian` orbit and passing more than one body to the `bodies` argument of `keplerian.System`. The following code snippet will produce the error, which fails with the error shown above.