jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BFGS/Quasi-Newton optimizers? #1400

Open proteneer opened 5 years ago

proteneer commented 5 years ago

Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py

But wasn't sure if anyone else was interested or had something similar already.

tetterl commented 3 years ago

Thanks for bringing BFGS support to JAX! Are there also plans to bring L-BFGS to JAX?

Joshuaalbert commented 3 years ago

Yes, there is a plan for L-BFGS. Q1 2021.

brianwa84 commented 3 years ago

If you need something now, you can use tfp.substrates.jax.optimizer.lbfgs_minimize
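A minimal sketch of that route, assuming the TFP-on-JAX substrate is installed and that lbfgs_minimize keeps its usual value-and-gradients signature (the module path and argument names below come from TFP's documented API, not from this thread):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

def rosenbrock(x):
    # Classic test function; the minimum is at x = (1, ..., 1).
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

# lbfgs_minimize expects a callable returning (value, gradients),
# which is exactly what jax.value_and_grad produces.
results = tfp.optimizer.lbfgs_minimize(
    jax.value_and_grad(rosenbrock),
    initial_position=jnp.zeros(5),
    max_iterations=500,
    tolerance=1e-8,
)
print(results.converged, results.position)
```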

salcats commented 3 years ago

Hi, has there been any progress on the pure JAX implementation of L-BFGS?

Jakob-Unfried commented 3 years ago

Hi, I have recently completed a jittable implementation of L-BFGS for my own purposes.

PR-ing to jax is on my todo-list but keeps getting bumped down... Since there is continued interest, I will prioritise it. Probably this weekend, no promises though.

A few design choices (@Joshuaalbert, @shoyer ?): I included the following features because they are convenient for my purposes, but they don't match the current interface of jax.scipy.optimize:

  • The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees
  • The optimisation parameters can be complex
  • I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed)
  • in addition to minimise I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output

I can, of course, "downgrade" all of that but maybe they should be part of jax?

I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in jax), but I can share it privately if someone wants to have a look.

shoyer commented 3 years ago
  • The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees

Yes, please! We need this for neural nets.
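(Until pytree support lands, one workaround sketch, not the proposed API: ravel the pytree into a flat vector with jax.flatten_util.ravel_pytree, hand it to the existing array-based BFGS, and unravel the result. The parameters and loss below are made up for illustration.)

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}  # an arbitrary pytree

def loss(p):
    return jnp.sum(p["w"] ** 2) + jnp.sum((p["b"] - 1.0) ** 2)

flat0, unravel = ravel_pytree(params)
result = minimize(lambda flat: loss(unravel(flat)), flat0, method="BFGS")
best_params = unravel(result.x)  # back to the original pytree structure
```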

  • The optimisation parameters can be complex

Also sounds useful! I guess you just add a few complex conjugates? As long as this matches up with JAX's gradient convention for complex numbers (which I assume it does), this seems like a straightforward extension.

  • I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed)

"log progress to console or to a file" sounds a little too opinionated to build into JAX itself, but this is absolutely very important! I would love to see what this looks like, and provide a "recipe" in the JAX docs so users can easily do this themselves.

  • in addition to minimise I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output

I also would leave this part out for JAX. It's totally sensible, but I don't think it's worth adding a second API end-point for all optimizers.
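For what it's worth, the wrapper a user would write themselves is only a couple of lines. A sketch, assuming the OptimizeResults NamedTuple returned by jax.scipy.optimize.minimize; maximise and log_likelihood are hypothetical names, not part of jax:

```python
from jax.scipy.optimize import minimize

def maximise(fun, x0, **kwargs):
    # Minimise the negated objective, then flip the sign back in the result.
    res = minimize(lambda x: -fun(x), x0, **kwargs)
    return res._replace(fun=-res.fun)

# usage: res = maximise(log_likelihood, x0, method="BFGS")
```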

I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in jax), but I can share it privately if someone wants to have a look.

The great thing about open source is that you don't need our permission to copy code :). As long as you keep the Apache license & copyright notice on it you are good to go!

Joshuaalbert commented 3 years ago

Agree with @shoyer. Your code can go in jax/_src/scipy/optimize/lbfgs.py and tests in tests/third_party/scipy, assuming you're using tests from scipy. Also, you should add the method to _src/scipy/optimize/minimize.py. I'll be happy to review this. I was off-and-on working on my own L-BFGS, but love that @Jakob-Unfried has got there first.

Jakob-Unfried commented 3 years ago

OK, the first draft is done. @Joshuaalbert, let me know what you think.

A couple of things to talk about:

PyTree support will be a separate PR

Joshuaalbert commented 3 years ago

Thanks for the paper reference. I think you're right that BFGS should also do something special for complex input. I'll think about the options and suggest something in a small PR. Highest precision in line search is fine. Agreed on leaving the inverse Hessian estimate as None. I looked over your code and it looks suitable for a PR. I need to go over the equations in more detail, but a first indication will be whether it passes tests against scipy's L-BFGS.

Jakob-Unfried commented 3 years ago

If you are looking at the complex formulas, note that I chose the JAX convention for all gradient-type variables (g and y).

This means that all gs and ys differ from the Sorber paper by a complex conjugation.

jecampagne commented 2 years ago

Oh, what a great thread! I was playing with jaxopt.ScipyBoundedMinimize with the "L-BFGS-B" method and ran into some problems because this scipy wrapper has to do jnp-to-onp conversion. So reading this thread may be the solution; where can I find the how-to with jax 0.3.10? Thanks
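(A minimal sketch of the pure-JAX route being asked about, assuming jaxopt's solver-class API with a .run method; whether the jaxopt release matching jax 0.3.10 ships LBFGS, and LBFGSB for the bounded case, is an assumption, not something confirmed in this thread.)

```python
import jax.numpy as jnp
import jaxopt

def f(x):
    return jnp.sum((x - 1.0) ** 2)

# Unconstrained pure-JAX L-BFGS; bound constraints would need jaxopt.LBFGSB
# where the installed jaxopt version provides it.
solver = jaxopt.LBFGS(fun=f, maxiter=100)
params, state = solver.run(jnp.zeros(3))
```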