HajimeKawahara closed this issue 2 months ago.
I ran transit_test.py and couldn't reproduce this error in the following environment: python==3.10.12, jaxopt==0.8.3, jax==0.4.23, scipy==1.12.0
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
return self[name]
KeyError: 'njev'
Could this be a scipy issue?
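This thread turns on which versions of python, jaxopt, jax and scipy each person has installed. A small helper can print them in the same `name==version` form used in these comments. The sketch below uses only the standard library (`platform`, `importlib.metadata`). `environment_report` is a hypothetical name and is not part of jkepler.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("jaxopt", "jax", "scipy")):
    """Map 'python' and each package name to its installed version string."""
    report = {"python": platform.python_version()}  # e.g. '3.10.12'
    for name in packages:
        try:
            # Version string read from the installed distribution's metadata.
            report[name] = version(name)
        except PackageNotFoundError:
            # Report missing packages instead of raising.
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    # Same name==version style as the comments in this thread.
    print(", ".join(f"{name}=={ver}" for name, ver in environment_report().items()))
```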
I see, thanks. I was using scipy==1.11.2. scipy==1.14.0 also gave the same error.
The code worked with scipy==1.14.0 for me. Maybe this is fixed in jaxopt==0.8.3 (https://github.com/google/jaxopt/pull/542), so can you try that version? That PR was made to fix this error: https://github.com/google/jaxopt/issues/536
@kemasuda Thanks! jaxopt==0.8.3 solves this issue. I will add jaxopt>=0.8.3 to the requirements.
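The fix agreed on here is a version floor in the requirements (jaxopt>=0.8.3). If a clearer error at import time is also wanted, a runtime check could look like the sketch below. The names `MIN_JAXOPT`, `parse_release`, `jaxopt_is_new_enough` and `check_jaxopt` are hypothetical and not part of jkepler. The version parsing is deliberately simple and assumes a numeric `X.Y.Z` prefix; the `packaging` library would be the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical floor: the first jaxopt release reported in this thread to fix the njev error.
MIN_JAXOPT = (0, 8, 3)


def parse_release(ver):
    """Turn '0.8.3' (or '0.8.3.post1') into (0, 8, 3); assumes a numeric X.Y.Z prefix."""
    parts = []
    for piece in ver.split(".")[:3]:
        digits = ""
        for ch in piece:  # keep only the leading digits, e.g. '0rc1' -> '0'
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad short versions such as '0.9' -> (0, 9, 0)
        parts.append(0)
    return tuple(parts)


def jaxopt_is_new_enough(installed):
    """True if the installed version string meets MIN_JAXOPT."""
    return parse_release(installed) >= MIN_JAXOPT


def check_jaxopt():
    """Raise ImportError early if jaxopt is missing or older than 0.8.3."""
    try:
        installed = version("jaxopt")
    except PackageNotFoundError:
        raise ImportError("jaxopt is not installed; jaxopt>=0.8.3 is required") from None
    if not jaxopt_is_new_enough(installed):
        raise ImportError(f"jaxopt {installed} found; jaxopt>=0.8.3 is required")
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.8.10" would sort before "0.8.3".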
While writing unit tests for the transit module (unittest_3 branch), I encountered the following error. The same error also occurs in the tf.optimize_transit_params section when running examples/transit.ipynb.
Environment: python==3.10.9, jaxopt==0.8.2, jax==0.4.31
On the unittest_3 branch:
~/jkepler/tests/unittests/transit(unittest_3)>python transit_test.py
optimizing t0 and period...
/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
Traceback (most recent call last):
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 223, in __getattr__
return self[name]
KeyError: 'njev'
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 50, in
test_compute_prediction()
File "/home/kawahara/jkepler/tests/unittests/transit/transit_test.py", line 41, in test_compute_prediction
popt = tf.optimize_transit_params(flux, error, t0, period, ecc, omega, b, rstar, rp_over_r, fit_ttvs=False)
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jkepler-0.0.1-py3.10.egg/jkepler/transit/transitfit.py", line 203, in optimize_transit_params
res = solver.run(p_init, bounds=bounds)
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
return make_custom_vjp_solver_fun(solver_fun, keys)(args, vals)
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
return solver_fun(*args, **kwargs)
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 457, in run
return self._run(init_params, bounds, args, **kwargs)
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/jaxopt/_src/scipy_wrappers.py", line 373, in _run
num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
File "/home/kawahara/anaconda3/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 225, in __getattr__
raise AttributeError(name) from e
AttributeError: njev. Did you mean: 'nfev'?
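The last frames show the mechanism. scipy's result object looks up attributes as dictionary keys (`return self[name]`) and re-raises a missing key as `AttributeError`, so `res.njev` fails whenever the chosen solver method did not report a Jacobian-evaluation count. The sketch below mimics only that behaviour with a hypothetical stand-in class, `ResultLike`, and shows a defensive read through `dict.get`. It assumes the real result object is a dict subclass, as the `self[name]` lookup suggests. It is an illustration only, not necessarily how jaxopt pull request 542 changed `scipy_wrappers.py`. `num_jac_evals` is also a hypothetical name.

```python
# Hypothetical stand-in for scipy.optimize.OptimizeResult: it copies only the
# __getattr__ behaviour visible in the traceback (dict lookup, KeyError -> AttributeError).
class ResultLike(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e


def num_jac_evals(res, default=0):
    """Return res["njev"] if the solver reported it, otherwise `default`."""
    # dict.get never raises, unlike the attribute access res.njev.
    return res.get("njev", default)


# A result from a solver that counts function evaluations but not Jacobian ones.
res = ResultLike(nfev=12, nit=5)
print(res.nfev)            # attribute access works for keys that exist
print(num_jac_evals(res))  # falls back to the default instead of raising
try:
    res.njev               # the same failure as in the traceback
except AttributeError as err:
    print("AttributeError:", err)
```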