google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Reducing compilation times #514

Status: Open · vroulet opened this 1 year ago

vroulet commented 1 year ago

The plan of this PR is to reduce the number of times the objective function is compiled. I started by adding a test in common_tests that counts how many times the objective is compiled for each solver; the results are below. I will first focus on a simple solver such as gradient descent to see whether the number of compilations can be reduced. I will then extend the work to each solver until the test passes with a single compilation of the objective for every solver.

AndersonWrapper                       : objective compiled 3 times
BFGS with zoom linesearch             : objective compiled 8 times
ArmijoSGD                             : objective compiled 5 times
GaussNewton                           : objective compiled 26 times
GradientDescent                       : objective compiled 3 times
LBFGS with zoom linesearch            : objective compiled 8 times
LBFGS with hager-zhang linesearch     : objective compiled 75 times
LBFGS with backtracking linesearch    : objective compiled 5 times
LevenbergMarquardt                    : objective compiled 30 times
NonlinearCG with zoom linesearch      : objective compiled 8 times
PolyakSGD                             : objective compiled 3 times
OptaxSolver                           : objective compiled 3 times
AndersonAcceleration                  : objective compiled 3 times
Broyden with backtracking linesearch  : objective compiled 11 times
Bisection                             : objective compiled 4 times
BlockCoordinateDescent                : objective compiled 5 times
LBFGSB with zoom linesearch           : objective compiled 8 times
ProjectedGradient                     : objective compiled 3 times
ProximalGradient                      : objective compiled 3 times
MirrorDescent                         : objective compiled 2 times
BacktrackingLineSearch                : objective compiled 5 times
HagerZhangLineSearch                  : objective compiled 40 times
ZoomLineSearch                        : objective compiled 6 times
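A minimal sketch of the kind of counting test described above, written in plain JAX rather than against jaxopt's test suite. The helper count_traces and the step functions step_a and step_b are hypothetical, not code from this PR. The helper counts how often the objective's Python body runs; inside a jitted function that body only runs while JAX traces it, so each count inside jit corresponds to one trace. The sketch also illustrates one assumed source of counts above 1: every separate call site of the objective inside a jitted step (here, an extra evaluation at a candidate point, as a simple line search might do) is traced on its own.

```python
import jax
import jax.numpy as jnp


def count_traces(fun):
    """Hypothetical helper: counts how many times the Python body of `fun` runs."""

    def wrapped(*args, **kwargs):
        wrapped.num_calls += 1
        return fun(*args, **kwargs)

    wrapped.num_calls = 0
    return wrapped


def objective(w):
    return jnp.sum((w - 1.0) ** 2)


# One call site: the objective is traced once, when step_a is first jitted.
fun_a = count_traces(objective)


@jax.jit
def step_a(w):
    _, grad = jax.value_and_grad(fun_a)(w)
    return w - 0.1 * grad


# Two call sites: value_and_grad plus a re-evaluation at the candidate point.
# Each call site is traced separately while step_b is being traced.
fun_b = count_traces(objective)


@jax.jit
def step_b(w):
    value, grad = jax.value_and_grad(fun_b)(w)
    candidate = w - 0.1 * grad
    new_value = fun_b(candidate)
    # Accept the candidate only if it decreases the objective.
    return jnp.where(new_value < value, candidate, w)


w_a = w_b = jnp.zeros(3)
for _ in range(10):
    # Same shapes and dtypes on every call, so later calls hit the jit cache.
    w_a = step_a(w_a)
    w_b = step_b(w_b)

print(fun_a.num_calls, fun_b.num_calls)  # 1 2
```

Such a counter would also increment on eager (un-jitted) calls of the objective, which do not compile it, so a raw count like this can overstate the number of compilations.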

vroulet commented 1 year ago

I'm now printing the type of the input:

The results are given below. The current implementation is not as bad as the first table suggested: OptaxSolver, for example, compiles the objective only once (during the update). The call to the objective in the init function (at least in OptaxSolver and PolyakSGD, and I think in other solvers too) does not seem to incur an additional compilation.

AndersonWrapper                       : objective compiled 2 times
BFGS with zoom linesearch             : objective compiled 7 times
ArmijoSGD                             : objective compiled 3 times
GaussNewton                           : objective compiled 25 times
GradientDescent                       : objective compiled 2 times
LBFGS with zoom linesearch            : objective compiled 7 times
LBFGS with hager-zhang linesearch     : objective compiled 38 times
LBFGS with backtracking linesearch    : objective compiled 3 times
LevenbergMarquardt                    : objective compiled 29 times
NonlinearCG with zoom linesearch      : objective compiled 7 times
PolyakSGD                             : objective compiled 1 time
OptaxSolver                           : objective compiled 1 time
AndersonAcceleration                  : objective compiled 2 times
Broyden with backtracking linesearch  : objective compiled 6 times
Bisection                             : objective compiled 2 times
BlockCoordinateDescent                : objective compiled 5 times
LBFGSB with zoom linesearch           : objective compiled 7 times
ProjectedGradient                     : objective compiled 2 times
ProximalGradient                      : objective compiled 2 times
MirrorDescent                         : objective compiled 2 times
BacktrackingLineSearch                : objective compiled 3 times
HagerZhangLineSearch                  : objective compiled 40 times
ZoomLineSearch                        : objective compiled 6 times
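Printing the input type, as in the second comment, is one way to separate two kinds of calls: a concrete array means the objective is evaluated eagerly, while a tracer means JAX is tracing it for compilation. The sketch below assumes a hypothetical solver with an eager init step and a jitted update step. CallCounter, init_state and update are illustrative names, not jaxopt's code. It uses jax.core.Tracer to classify each call instead of printing the type.

```python
import jax
import jax.numpy as jnp


class CallCounter:
    """Hypothetical wrapper that counts eager and traced calls separately."""

    def __init__(self, fun):
        self.fun = fun
        self.eager_calls = 0
        self.traced_calls = 0

    def __call__(self, w):
        # A tracer means JAX is tracing the objective (a compilation);
        # a concrete array means a plain eager evaluation.
        if isinstance(w, jax.core.Tracer):
            self.traced_calls += 1
        else:
            self.eager_calls += 1
        return self.fun(w)


def objective(w):
    return jnp.sum((w - 1.0) ** 2)


fun = CallCounter(objective)


def init_state(w):
    # Eager evaluation of the objective, as a solver's init step might do.
    return fun(w)


@jax.jit
def update(w):
    # Traced once on the first call, then served from the jit cache.
    value, grad = jax.value_and_grad(fun)(w)
    return w - 0.1 * grad, value


w = jnp.zeros(3)
value = init_state(w)
for _ in range(5):
    w, value = update(w)

print(fun.eager_calls, fun.traced_calls)  # 1 1
```

In this toy setup the eager call in init_state is counted but does not compile the objective, which mirrors the observation above for OptaxSolver and PolyakSGD.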