google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Added a new API for warm starting the inverse Hessian approximation in LBFGS #354

Closed: zaccharieramzi closed this 1 year ago

zaccharieramzi commented 1 year ago

This fixes #351.

@mblondel I couldn't use your suggestion of creating a new init type LBFGSInit because the init_params variable is used by both init_state and update, so I would have had to add case distinctions in both functions, which seemed unreasonable. Instead, I took the approach I saw in some other iterative solvers: adding an extra keyword argument to init_state, update and _value_and_grad_fun.

I added a test to make sure that this runs, but I am not sure whether we also need a test showing that it actually improves convergence in some cases. I also don't know whether we should test that differentiation still works.
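For context, a minimal sketch of how LBFGS is used today (made-up objective and data); the warm-start argument discussed in this PR is not shown here:

```python
import jax.numpy as jnp
import jaxopt

def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

X = jnp.arange(24.0).reshape(8, 3) / 24.0
y = jnp.arange(8.0) / 8.0
w0 = jnp.zeros(3)

solver = jaxopt.LBFGS(fun=loss, maxiter=10)
params, state = solver.run(w0, X, y)
# `state` carries the curvature information behind the inverse Hessian
# approximation, which is what warm starting would let a new run reuse.
```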

mblondel commented 1 year ago

> I couldn't use your suggestion of creating a new init type LBFGSInit because the init_params variable is used by both init_state and update, so I would have had to add case distinctions in both functions, which seemed unreasonable.

I'm not sure I see why that would be an issue. What I propose is that from now on, init_params can be either an array/pytree or a named tuple. Actually, let's just call the new named tuple jaxopt.Initialization, since the Hessian initialization could potentially be useful in other solvers. What's wrong with using if isinstance(init_params, jaxopt.Initialization): within init_state and update?
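For illustration, a rough sketch of this proposal; `Initialization`, its fields and the bare `init_state` below are hypothetical and only meant to show the isinstance dispatch, not actual jaxopt code:

```python
from typing import Any, NamedTuple, Optional

# Hypothetical named tuple, as proposed above; not an existing jaxopt class.
class Initialization(NamedTuple):
    params: Any
    inv_hessian: Optional[Any] = None

def init_state(init_params, *args, **kwargs):
    # Dispatch on the type of init_params, as suggested in the comment above.
    if isinstance(init_params, Initialization):
        params, inv_hessian = init_params.params, init_params.inv_hessian
    else:
        params, inv_hessian = init_params, None
    # ... build the LBFGS state from `params`, seeding the inverse Hessian
    # approximation with `inv_hessian` when it is provided ...
    return params, inv_hessian
```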

The extra arguments of run are there to be passed on to the objective function. Those are the parameters w.r.t. which we want to differentiate the solution. We should avoid adding new arguments; I don't think we want to differentiate w.r.t. the Hessian initialization.
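A minimal sketch of that pattern (made-up ridge-style objective and data): the extra argument of run is forwarded to the objective, and the solution can be differentiated w.r.t. it.

```python
import jax
import jax.numpy as jnp
import jaxopt

def loss(w, l2reg, X, y):
    return jnp.mean((X @ w - y) ** 2) + l2reg * jnp.sum(w ** 2)

X = jnp.arange(24.0).reshape(8, 3) / 24.0
y = jnp.arange(8.0) / 8.0
w0 = jnp.zeros(3)

def solution(l2reg):
    # l2reg, X and y are extra arguments of run, forwarded to `loss`.
    return jaxopt.LBFGS(fun=loss, maxiter=50).run(w0, l2reg, X, y).params

# Differentiate the solution w.r.t. the extra argument l2reg.
grad_l2reg = jax.grad(lambda r: jnp.sum(solution(r)))(0.1)
```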

zaccharieramzi commented 1 year ago

So actually it also has to be in _value_and_grad_fun: I think the only problem is going to be the repeated code, but maybe I can find a way to handle that in a DRY fashion.

In some cases the extra arguments of run are not necessarily passed only to the objective function (like hyperparams_prox for proximal gradient descent), but I agree that they are usually there because you want to differentiate w.r.t. them, which is not the case here.
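A small sketch of that ProximalGradient case (made-up data): hyperparams_prox is an extra argument of run that goes to the prox operator rather than the objective, yet the solution can still be differentiated w.r.t. it.

```python
import jax
import jax.numpy as jnp
import jaxopt
from jaxopt.prox import prox_lasso

def least_squares(w, X, y):
    return 0.5 * jnp.mean((X @ w - y) ** 2)

X = jnp.arange(24.0).reshape(8, 3) / 24.0
y = jnp.arange(8.0) / 8.0
w0 = jnp.zeros(3)

pg = jaxopt.ProximalGradient(fun=least_squares, prox=prox_lasso, maxiter=200)
# hyperparams_prox is routed to prox_lasso; X and y go to the objective.
lasso_sol = pg.run(w0, hyperparams_prox=0.01, X=X, y=y).params

# Differentiating the solution w.r.t. the regularization strength.
d_sol = jax.grad(lambda lam: jnp.sum(pg.run(w0, lam, X, y).params))(0.01)
```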

I will revert to something more along the lines of what you originally had in mind.

zaccharieramzi commented 1 year ago

@mblondel all done: I also had to override the _make_zero_step method. Maybe this, in addition to the two if statements in update and _value_and_grad_fun, means it could all be moved into the base IterativeSolver, but I didn't want to create a mess.

mblondel commented 1 year ago

> In some cases the extra arguments of run are not necessarily passed only to the objective function

I agree that the handling of prox hyperparameters is not ideal, but the rule is that run and optimality_fun should have the same signature.
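A brief sketch of that convention, assuming ProximalGradient's optimality_fun takes the same trailing arguments as run (made-up objective and data):

```python
import jax.numpy as jnp
import jaxopt
from jaxopt.prox import prox_lasso

def least_squares(w, X, y):
    return 0.5 * jnp.mean((X @ w - y) ** 2)

X = jnp.eye(3)
y = jnp.array([1.0, 2.0, 3.0])
w0 = jnp.zeros(3)

pg = jaxopt.ProximalGradient(fun=least_squares, prox=prox_lasso)
# run(init_params, hyperparams_prox, *args) ...
params = pg.run(w0, 0.01, X, y).params
# ... and optimality_fun(params, hyperparams_prox, *args) share the same
# trailing arguments, which is what implicit differentiation relies on.
residual = pg.optimality_fun(params, 0.01, X, y)
```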

zaccharieramzi commented 1 year ago

@mblondel when you mention the narrative doc, do you mean this page? I'd be happy to add some content about L-BFGS, maybe in a subsequent PR to keep things tight. Wdyt?

zaccharieramzi commented 1 year ago

all commits squashed

zaccharieramzi commented 1 year ago

@mblondel merged main, re-tested and all green, and squashed commits