The LevenbergMarquardt implementation does not accept PyTree parameters, giving TypeError: primal and tangent arguments to jax.jvp must have the same tree structure at levenberg_marquardt.py, line 534.
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 359, in run
    return run(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 301, in _run
    state = self.init_state(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 216, in init_state
    jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 535, in _jtj_diag_op
    return jax.vmap(diag_op)(jnp.eye(len(params))).T
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 534, in <lambda>
    diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 528, in _jtj_op
    _, jvp_val = jax.jvp(fun_with_args, (params,), (vec,))
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((CustomNode(namedtuple[CalibrationParams], [*, *]),)) whereas tangents have tree structure PyTreeDef((*,)).
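The same TypeError can be reproduced with jax.jvp alone. The sketch below assumes a hypothetical two-field CalibrationParams and a made-up linear residual (the reporter's actual model is not shown here). It mimics the jnp.eye(len(params)) tangents built at line 535, then shows one possible workaround: flattening the parameters with jax.flatten_util.ravel_pytree so that each basis vector has the same tree structure as the primal.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


class CalibrationParams(NamedTuple):
    # Hypothetical two-field container standing in for the reporter's
    # CalibrationParams (its real field names are not shown in the issue).
    gain: jnp.ndarray
    offset: jnp.ndarray


def residual(params, x, y):
    # Hypothetical linear-model residual: gain * x + offset - y.
    return params.gain * x + params.offset - y


x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
params = CalibrationParams(gain=jnp.array(1.0), offset=jnp.array(0.0))

# Mimic _jtj_diag_op: tangents are rows of jnp.eye(len(params)).
# len(params) counts the namedtuple's fields, and each row is a plain
# array, so its tree structure does not match the namedtuple primal.
try:
    jax.jvp(lambda p: residual(p, x, y), (params,), (jnp.eye(len(params))[0],))
    mismatch_raised = False
except TypeError as err:
    mismatch_raised = "same tree structure" in str(err)

# Workaround sketch: flatten the PyTree into a 1-D array so that rows of
# jnp.eye(flat_params.size) have the same structure as the primal.
flat_params, unravel = ravel_pytree(params)


def flat_residual(flat, x, y):
    # Rebuild the namedtuple before calling the original residual.
    return residual(unravel(flat), x, y)


def jtj_diag(f, p):
    # diag(J^T J), one JVP per basis vector v. _jtj_diag_op computes
    # v.T @ (J^T J v); here the same value is obtained as ||J v||^2.
    def column_sq_norm(v):
        _, jv = jax.jvp(f, (p,), (v,))  # jv = J @ v
        return jnp.sum(jv ** 2)
    return jax.vmap(column_sq_norm)(jnp.eye(p.size))


diag = jtj_diag(lambda p: flat_residual(p, x, y), flat_params)
```

Under these assumptions, flat_residual could be passed as residual_fun to jaxopt.LevenbergMarquardt and run with flat_params, with unravel applied to the returned parameters. This has not been checked against the reporter's MVCE, and a PyTree-aware _jtj_diag_op would avoid the need for the wrapper.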
Description
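The LevenbergMarquardt implementation does not accept PyTree parameters: passing a CalibrationParams namedtuple as init_params fails in init_state with the TypeError above. The cause appears to be _jtj_diag_op (line 535), which builds tangent vectors with jnp.eye(len(params)) and therefore assumes params is a flat 1-D array. For a namedtuple, len(params) is the number of fields (2 here), and each row v passed through the lambda at line 534 into jax.jvp (line 528) is a plain array whose tree structure, PyTreeDef((*,)), does not match that of the primal.
MVCE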