SharmaLlama / ticktack

An open-source carbon box model implementation built on JAX.
https://sharmallama.github.io/ticktack
MIT License

`scipy.optimize.minimize` vs `JAX` tension. #19

Closed Jordan-Dennis closed 2 years ago

Jordan-Dennis commented 2 years ago

Often when I pass a `target_C_14` to `run`, I am hit with an error saying that the ndarray conversion method `__array__` was called on a JAX tracer (full error below). I usually get around this by using the `SingleFitter`, which seems to have some solution. I'm raising this as an issue because it keeps happening and is very difficult to debug.

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[1])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function run at /home/jordan/anaconda3/envs/kitkat/lib/python3.9/site-packages/ticktack-0.1.3.2-py3.9.egg/ticktack/ticktack.py:473 for jit, this value became a tracer due to JAX operations on these lines:

  operation a:f64[1] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line /home/jordan/anaconda3/envs/kitkat/lib/python3.9/site-packages/ticktack-0.1.3.2-py3.9.egg/ticktack/ticktack.py:433 (_equilibrate_guttler)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Jordan-Dennis commented 2 years ago

It seems to be occurring in the scipy internals (see the following line in the stack trace):

~/anaconda3/envs/kitkat/lib/python3.9/site-packages/scipy-1.8.0rc2-py3.9-linux-x86_64.egg/scipy/optimize/_minimize.py in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    514 
    515     """
--> 516     x0 = np.atleast_1d(np.asarray(x0))
    517     if x0.dtype.kind in np.typecodes["AllInteger"]:
    518         x0 = np.asarray(x0, dtype=float)
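A minimal sketch of the failure mode, independent of ticktack: `np.asarray` calls `__array__()` on its argument, and a value inside a `jax.jit`-traced function is a tracer with no concrete data to hand over. The function name `traced_to_numpy` is hypothetical and only illustrates the mechanism; it is not the code path in `ticktack.py`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def traced_to_numpy(x):
    # Inside jit, x is a tracer. np.asarray triggers x.__array__(),
    # which JAX rejects because the tracer holds no concrete values.
    return jnp.sum(np.asarray(x))


def reproduces_error():
    """Return True if the NumPy conversion fails under tracing."""
    try:
        traced_to_numpy(jnp.ones(1))
    except jax.errors.TracerArrayConversionError:
        return True
    return False


raised = reproduces_error()

# Outside of jit the same conversion is fine, because the array is concrete.
concrete = np.asarray(jnp.ones(1))
```

The same thing presumably happens when `scipy.optimize.minimize` receives a traced `x0`: its first step is `np.asarray(x0)`, as the stack trace shows.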

Jordan-Dennis commented 2 years ago

This should be easily resolved by switching to `jax.scipy`.
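A rough sketch of what that switch could look like, assuming the optimizer in question can be replaced by `jax.scipy.optimize.minimize`. That function only supports `method="BFGS"` and expects a one-dimensional `x0`, so it is not a drop-in replacement for every `scipy.optimize.minimize` call. The `loss` function and `fit` wrapper below are hypothetical stand-ins, not part of ticktack.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def loss(params, target):
    # Hypothetical objective: squared distance between params and target.
    return jnp.sum((params - target) ** 2)


@jax.jit
def fit(x0, target):
    # jax.scipy.optimize.minimize is written in JAX, so x0 may be a tracer
    # here; no conversion to a NumPy ndarray takes place.
    result = minimize(loss, x0, args=(target,), method="BFGS")
    return result.x


target = jnp.array([1.0, -2.0, 3.0])
best = fit(jnp.zeros(3), target)
```

If other methods or bounds are needed, an alternative is to keep `scipy.optimize.minimize` but call it outside any jitted function with a concrete NumPy `x0`, and jit only the objective.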