google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
939 stars, 66 forks

removed extra computations in error computation for PGD, GD, and Mirror Descent #376

Closed: zaccharieramzi closed this 1 year ago

zaccharieramzi commented 1 year ago

This PR also removes an extra computation in the accelerated versions of GD and PGD; it was needed to compute an old version of the error.

Added unit tests to make sure the number of calls for GD and PGD does not regress.
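For readers who have not seen the diff, here is a minimal, self-contained sketch of the general idea in plain Python. It is not the jaxopt implementation: `CountingGrad`, `gd_with_extra_eval`, and `gd_reusing_grad` are hypothetical names, and the error is assumed to be the gradient magnitude purely for illustration. It contrasts a loop that re-evaluates the gradient only to report an error with one that reuses the gradient already computed for the step, and shows how a call counter can guard against regressions.

```python
class CountingGrad:
    """Hypothetical helper: wraps a gradient function and counts its calls."""

    def __init__(self, grad_fn):
        self.grad_fn = grad_fn
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return self.grad_fn(x)


def gd_with_extra_eval(grad, x, stepsize, iters):
    """Illustrative 'before' variant: the error needs a second gradient call."""
    error = float("inf")
    for _ in range(iters):
        x = x - stepsize * grad(x)
        error = abs(grad(x))  # extra evaluation, used only for the error
    return x, error


def gd_reusing_grad(grad, x, stepsize, iters):
    """Illustrative 'after' variant: the error reuses the step's gradient."""
    error = float("inf")
    for _ in range(iters):
        g = grad(x)
        error = abs(g)  # no additional gradient call
        x = x - stepsize * g
    return x, error


def quadratic_grad(x):
    """Gradient of f(x) = (x - 3)^2."""
    return 2.0 * (x - 3.0)


iters = 10
counted_before = CountingGrad(quadratic_grad)
counted_after = CountingGrad(quadratic_grad)

x_before, _ = gd_with_extra_eval(counted_before, 0.0, 0.25, iters)
x_after, _ = gd_reusing_grad(counted_after, 0.0, 0.25, iters)

print(counted_before.calls, counted_after.calls)
```

A regression test in this style asserts an exact call count per iteration, so any change that reintroduces a redundant evaluation fails immediately.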