kemasuda / jkepler

JAX code for modeling Keplerian orbits
MIT License

AttributeError in TransitFit.optimize_transit_params, possibly caused by jaxopt #6

Closed: HajimeKawahara closed this issue 2 months ago

HajimeKawahara commented 2 months ago

While writing unit tests for the transit module (unittest_3 branch), I encountered the following error. The same error also occurs at the tf.optimize_transit_params step when running examples/transit.ipynb.

environment: python==3.10.9, jaxopt==0.8.2, jax==0.4.31

optimizing t0 and period...

/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
Traceback (most recent call last):
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
    return self[name]
KeyError: 'njev'

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 50, in <module>
    test_compute_prediction()
  File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 41, in test_compute_prediction
    popt = tf.optimize_transit_params(flux, error, t0, period, ecc, omega, b, rstar, rp_over_r, fit_ttvs=False)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jkepler-0.0.1-py3.10.egg/jkepler/transit/transitfit.py", line 203, in optimize_transit_params
    res = solver.run(p_init, bounds=bounds)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 457, in run
    return self._run(init_params, bounds, *args, **kwargs)
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 373, in _run
    num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 225, in __getattr__
    raise AttributeError(name) from e
AttributeError: njev. Did you mean: 'nfev'?
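The last frames show the mechanism: SciPy's `OptimizeResult` is a dict whose attribute access raises `AttributeError` (chained from a `KeyError`) when a key is absent, and jaxopt 0.8.2 reads `res.njev` unconditionally in `scipy_wrappers.py`. The sketch below reproduces that behavior with a hand-built result (the field values are made up) and shows a defensive read. `read_count` is a hypothetical helper, and this is not necessarily how jaxopt itself fixed the problem.

```python
from scipy.optimize import OptimizeResult


def read_count(res, key, default=0):
    """Read an evaluation counter that a SciPy result object may not contain.

    OptimizeResult is a dict subclass, so .get() avoids the AttributeError
    that attribute access (res.njev) raises for a missing key.
    """
    return res.get(key, default)


# Hand-built result with "nfev" but no "njev" (values are illustrative only).
res = OptimizeResult(x=[0.0], nfev=12)

try:
    res.njev  # same access pattern as the failing line in the traceback
except AttributeError as err:
    print("AttributeError:", err)

print(read_count(res, "njev"))  # 0
print(read_count(res, "nfev"))  # 12
```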

kemasuda commented 2 months ago

I ran transit_test.py and couldn't reproduce this error in my environment: python==3.10.12, jaxopt==0.8.3, jax==0.4.23, scipy==1.12.0. The error in your traceback is raised here:

  File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
    return self[name]
KeyError: 'njev'

Could this be a scipy issue?

HajimeKawahara commented 2 months ago

I see, thanks. I was using scipy==1.11.2, and scipy==1.14.0 also produced the same error.
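Since the outcome depends on the combination of installed versions, it helps to report them the same way on each machine. A minimal sketch using only the standard library (`sys` and `importlib.metadata`); `report_versions` is a hypothetical helper, not part of jkepler:

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("jax", "jaxopt", "scipy")):
    """Return the Python version and the installed versions of `packages`.

    Packages that are not installed are reported as None.
    """
    info = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = None
    return info


if __name__ == "__main__":
    # Print in the same "name==version" style used in this thread.
    for name, ver in report_versions().items():
        print(f"{name}=={ver}")
```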

kemasuda commented 2 months ago

The code also worked with scipy==1.14.0 for me. Maybe this was fixed in jaxopt==0.8.3 (https://github.com/google/jaxopt/pull/542); could you try that version? That change was made to fix this error: https://github.com/google/jaxopt/issues/536

HajimeKawahara commented 2 months ago

@kemasuda Thanks! jaxopt==0.8.3 solves this issue. I will add jaxopt>=0.8.3 to the requirements.
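Besides pinning jaxopt>=0.8.3 in the requirements, a package could fail early with a clear message when an older jaxopt is installed. A minimal sketch follows; `meets_minimum` and `check_jaxopt` are hypothetical names, and the simplified parser handles only plain numeric versions (the `packaging` library's `Version` class is the more robust choice).

```python
from importlib.metadata import PackageNotFoundError, version


def meets_minimum(installed: str, minimum: str = "0.8.3") -> bool:
    """Return True if `installed` is at least `minimum`.

    Simplified comparison of dot-separated integers only; strings such as
    "0.8.4.dev0" or "0.9rc1" are not supported.
    """
    def as_tuple(v: str) -> tuple:
        return tuple(int(part) for part in v.split("."))

    return as_tuple(installed) >= as_tuple(minimum)


def check_jaxopt(minimum: str = "0.8.3") -> str:
    """Raise ImportError unless jaxopt >= minimum is installed; return its version."""
    try:
        installed = version("jaxopt")
    except PackageNotFoundError:
        raise ImportError("jaxopt is not installed") from None
    if not meets_minimum(installed, minimum):
        raise ImportError(f"jaxopt>={minimum} is required, found {installed}")
    return installed


# Example checks on the two jaxopt versions discussed in this thread.
print(meets_minimum("0.8.2"))  # False
print(meets_minimum("0.8.3"))  # True
```

Comparing integer tuples rather than raw strings keeps "0.10.0" above "0.8.3"; a library could call `check_jaxopt()` once at import time, with the requirement pin remaining the primary fix.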