google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

diag(JTJ) can be more efficient #580

Open Joshuaalbert opened 7 months ago

Joshuaalbert commented 7 months ago

In the LM (Levenberg-Marquardt) method, max(diag(JTJ)) is used to set the damping factor. As per option 2 in https://github.com/google/jax/issues/19711, it can be computed more efficiently than in the current implementation. I discovered this when I hit some OOM problems with jaxopt's LM.
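To illustrate the general idea, here is a minimal sketch, not jaxopt's code and not necessarily the exact approach in the linked JAX issue. It relies on the identity diag(JTJ)[i] = ||J e_i||^2, so each diagonal entry can be obtained from one Jacobian-vector product without materializing J or JTJ. The `residual` function and the helper names `diag_jtj_dense` and `diag_jtj_columnwise` are hypothetical, chosen only for this example.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Hypothetical residual function, for illustration only (4 residuals, 3 params).
    A = jnp.arange(12.0).reshape(4, 3)
    return jnp.sin(A @ x) + x.sum()


def diag_jtj_dense(f, x):
    # Reference version: builds the full Jacobian J (m x n) and J^T J (n x n).
    J = jax.jacfwd(f)(x)
    return jnp.diag(J.T @ J)


def diag_jtj_columnwise(f, x):
    # Memory-lean version: column i of J is the JVP of f along basis vector e_i,
    # and diag(J^T J)[i] is the squared norm of that column.
    def col_sq_norm(i):
        e_i = jnp.zeros_like(x).at[i].set(1.0)
        _, col = jax.jvp(f, (x,), (e_i,))
        return jnp.sum(col ** 2)

    # lax.map runs sequentially, so only one column of J is live at a time.
    return jax.lax.map(col_sq_norm, jnp.arange(x.size))


x = jnp.array([0.1, -0.2, 0.3])
d_dense = diag_jtj_dense(residual, x)
d_cols = diag_jtj_columnwise(residual, x)
damping_scale = jnp.max(d_cols)  # the quantity used to set the damping factor
```

The sequential `jax.lax.map` trades speed for memory. Swapping it for `jax.vmap` over the same indices would be faster but would again hold all columns of J at once, so which variant is preferable depends on the problem size.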