I propose replacing the JAX NumPy operations in `LevenbergMarquardt` with the corresponding ones in `tree_util` to address issues #505 and #579. With this change, the snippet in issue #505 appears to run correctly, both with and without geodesic acceleration, when using the solver `solve_cg`.
However, QR, LU, and Cholesky still fail since they require the flattened versions of both the Jacobian and parameters.
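To illustrate why the dense factorizations need flattened inputs, here is a minimal sketch of a damped step solved with Cholesky. It is not the jaxopt implementation: `dense_lm_step` and its arguments are hypothetical names, and the residual function is assumed to return a 1-D array. The parameters are flattened with `jax.flatten_util.ravel_pytree` so that the Jacobian becomes an ordinary 2-D matrix.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.linalg import cho_factor, cho_solve


def dense_lm_step(residual_fun, params, damping):
    """Hypothetical dense damped step; assumes residual_fun returns a 1-D array."""
    # Flatten the pytree so the Jacobian is a plain (m, n) matrix.
    flat, unravel = ravel_pytree(params)

    def flat_residual(x):
        return residual_fun(unravel(x))

    jac = jax.jacfwd(flat_residual)(flat)  # dense Jacobian, shape (m, n)
    residual = flat_residual(flat)
    # Damped normal equations: (J^T J + damping * I) delta = -J^T r.
    jtj = jac.T @ jac + damping * jnp.eye(flat.size)
    delta = cho_solve(cho_factor(jtj), -(jac.T @ residual))
    # Map the flat update back to the original pytree structure.
    return unravel(flat + delta)
```

The cost of this route is the dense `(m, n)` Jacobian, which is what becomes prohibitive for large problems.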
Regarding the computation of the initial value of the `damping_factor`: using `self.damping_parameter * jnp.max(jtj_diag)` requires materializing the full identity matrix. Perhaps, for large problems like the one in issue #579, it would be useful to let the user choose an initial `damping_factor` without computing `jtj_diag`, as in the original paper by Marquardt (https://www.jstor.org/stable/2098941, p. 438)?
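A rough sketch of what such an option could look like follows. The function `init_damping_factor` and the argument `initial_damping_factor` are hypothetical names, not part of jaxopt, and the fallback branch only assumes that `jtj_diag` is obtained by pushing every basis vector through the JVP, which is where the identity matrix appears.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_damping_factor(residual_fun, params, damping_parameter=1e-6,
                        initial_damping_factor=None):
    """Hypothetical helper: the user-supplied value wins, else use max(diag(J^T J))."""
    if initial_damping_factor is not None:
        # Cheap path: no Jacobian information is needed at all.
        return jnp.asarray(initial_damping_factor)

    flat, unravel = ravel_pytree(params)
    _, jvp = jax.linearize(lambda x: residual_fun(unravel(x)), flat)
    # Expensive path: one JVP per basis vector, so an (n, n) identity is built.
    columns = jax.vmap(jvp)(jnp.eye(flat.size))  # row i holds J @ e_i
    jtj_diag = jnp.sum(columns ** 2, axis=1)     # diag(J^T J)
    return damping_parameter * jnp.max(jtj_diag)
```

With `n` parameters, the fallback allocates an `n x n` identity and an `n x m` array of Jacobian columns, while the user-supplied branch allocates nothing.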
Hello @gbruno16,
Good to see you on this repo too!
A few comments:
We are currently in the process of migrating jaxopt into optax, so it may be worth thinking about creating such an optimizer in optax directly. I'd be happy to help with this, but be aware that it'll take some time. In particular, optax works with gradient transformations and does not yet handle solvers the way jaxopt does. But again, I have thought about it, so we could do it together (ping me by mail if you are interested!).
If you want to stick with jaxopt's API, I think you can largely revamp this function or even start from scratch with a simpler implementation (it would be helpful for the migration anyway). By a simpler implementation, I mean one that
- never materializes the Jacobian and only works with JVPs/VJPs (note that once you have, e.g., the JVP, you can use `jax.linear_transpose` to get the VJP);
- uses a CG solver that can work directly with linear operators (so no headache of doing an LU etc., and no materialization of the Jacobian);
- has no geodesic acceleration to start with (so the code is a bit simpler; geodesic acceleration can be added later).
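The three points above could be sketched roughly as follows. This is only an illustration under stated assumptions, not a proposed final implementation: `lm_step` and its arguments are hypothetical names, the damping is a fixed positive scalar, and there is no accept/reject logic or damping update. It uses `jax.linearize` for the JVP, `jax.linear_transpose` for the VJP, and `jax.scipy.sparse.linalg.cg`, which accepts a linear operator given as a function on pytrees.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def lm_step(residual_fun, params, damping):
    """Hypothetical matrix-free damped Gauss-Newton step (J is never formed)."""
    # Residual at `params` and the linear map v -> J v.
    residual, jvp = jax.linearize(residual_fun, params)
    # Transposing the linear JVP gives u -> J^T u; it returns a 1-tuple.
    vjp = jax.linear_transpose(jvp, params)

    def normal_matvec(v):
        # (J^T J + damping * I) v, computed leaf by leaf on the pytree.
        (jtjv,) = vjp(jvp(v))
        return jax.tree_util.tree_map(lambda a, b: a + damping * b, jtjv, v)

    (grad,) = vjp(residual)  # J^T r
    rhs = jax.tree_util.tree_map(jnp.negative, grad)
    # CG only needs matrix-vector products, so pytrees work directly.
    delta, _ = cg(normal_matvec, rhs)
    return jax.tree_util.tree_map(jnp.add, params, delta)
```

Because the operator is symmetric positive definite for any positive damping, CG applies without a factorization, and the parameters can stay in their original pytree structure throughout.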