google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Hager-Zhang linesearch does not work with non-jittable functions #448

Open mblondel opened 1 year ago

mblondel commented 1 year ago

Currently, the Hager-Zhang linesearch uses its own lax.while_loop (see this line). Because lax.while_loop traces its condition and body functions, the objective function must be jittable.

In order to support non-jittable functions, we should use JAXopt's while_loop instead, so that we can propagate the jit and unroll options to it.
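To make the failure mode concrete, here is a minimal sketch of the dispatch idea, not JAXopt's actual code. The helper `while_loop`, its `jit` flag, and the toy `objective` are hypothetical names made up for illustration; JAXopt's own loop helper also handles the unroll option mentioned above, which this sketch leaves out.

```python
import jax.numpy as jnp
from jax import lax


def while_loop(cond_fun, body_fun, init_val, jit=True):
    """Hypothetical dispatcher: traced loop if jit=True, plain Python loop otherwise."""
    if jit:
        # lax.while_loop traces cond_fun and body_fun, so both must be jittable.
        return lax.while_loop(cond_fun, body_fun, init_val)
    # Plain Python loop: values stay concrete, so arbitrary Python code is allowed.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def objective(step):
    # Non-jittable on purpose: float() needs a concrete value, not a tracer.
    return (float(step) - 0.25) ** 2


def cond_fun(step):
    # Keep shrinking the step while the objective is above a tolerance.
    return objective(step) > 0.1


def body_fun(step):
    return step * 0.5


init = jnp.asarray(4.0)


def fails_under_tracing():
    """Return True if the traced loop rejects the non-jittable objective."""
    try:
        while_loop(cond_fun, body_fun, init, jit=True)
    except TypeError:  # JAX's concretization errors subclass TypeError
        return True
    return False


# The Python-loop path accepts the same objective without complaint.
result = while_loop(cond_fun, body_fun, init, jit=False)
```

With `jit=False` the loop runs eagerly and the non-jittable objective is accepted, while `jit=True` fails at trace time. A linesearch that hard-codes `lax.while_loop` only has the second path, which is why the option needs to be propagated.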

CC @emilyfertig @srvasude

See #444 for context