proteneer opened this issue 5 years ago
Thanks for bringing BFGS support to JAX! Are there also any plans to bring L-BFGS to JAX?
Yes, there is a plan for L-BFGS. Q1 2021.
If you need something now, you can use tfp.substrates.jax.optimizers.lbfgs_minimize
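For anyone who wants to try that route, here is a rough sketch (module path, argument names, and the toy objective are from memory, so double-check against the TFP docs for your version):

```python
import jax
import jax.numpy as jnp
# TFP's JAX substrate; provides tfp.optimizer.lbfgs_minimize without TensorFlow.
from tensorflow_probability.substrates import jax as tfp

def objective(x):
    # Toy quadratic with its minimum at (1, 2, 3).
    return jnp.sum((x - jnp.array([1.0, 2.0, 3.0])) ** 2)

# lbfgs_minimize expects a function returning (value, gradient).
value_and_grad = jax.value_and_grad(objective)

results = tfp.optimizer.lbfgs_minimize(
    value_and_grad,
    initial_position=jnp.zeros(3),
    tolerance=1e-8,
)
print(results.converged, results.position)
```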
Hi, has there been any progress on the pure-JAX implementation of L-BFGS?
Hi, I have recently completed a jittable implementation of L-BFGS for my own purposes.
PR-ing it to jax is on my todo list but keeps getting bumped down... Since there is continued interest, I will prioritise it. Probably this weekend, no promises though.
A few design choices (@Joshuaalbert, @shoyer ?):
I included the following features because they are convenient for my purposes, but they don't match the current interface of jax.scipy.optimize:

- The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees.
- The optimisation parameters can be complex.
- I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed).
- In addition to minimise, I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output.

I can, of course, "downgrade" all of that, but maybe they should be part of jax?

I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in jax), but I can share it privately if someone wants to have a look.
- The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees
Yes, please! We need this for neural nets.
- The optimisation parameters can be complex
Also sounds useful! I guess you just add a few complex conjugates? As long as this matches up with JAX's gradient convention for complex numbers (which I assume it does) this seems like a straightforward extension.
- I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed)
"log progress to console or to a file" sounds a little too opinionated to build into JAX itself, but this is absolutely very important! I would love to see what this looks like, and provide a "recipe" in the JAX docs so users can easily do this themselves.
- In addition to minimise, I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output
I also would leave this part out for JAX. It's totally sensible, but I don't think it's worth adding a second API end-point for all optimizers.
I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in jax), but I can share it privately if someone wants to have a look.
The great thing about open source is that you don't need our permission to copy code :). As long as you keep the Apache license & copyright notice on it you are good to go!
Agree with @shoyer. Your code can go in jax/_src/scipy/optimize/lbfgs.py and the tests in tests/third_party/scipy, assuming you're using tests from scipy. Also, you should add the method to _src/scipy/optimize/minimize.py. I'll be happy to review this. I was off-and-on working on my own L-BFGS, but I love that @Jakob-Unfried got there first.
Ok, first draft is done, @Joshuaalbert let me know what you think.
A couple of things to talk about:
- Complex inputs: one option is to pack the real and imaginary parts into a single real vector and wrap the objective as lambda x: original_fun(x[:dim] + 1.j * x[dim:]), using the existing implementation (see the sketch after this list).
- The returned inverse Hessian estimate is currently None. Scipy constructs a LinearOperator (see here), but I think that would be overkill, right? In particular, constructing an entire matrix would defeat the whole point of using L-BFGS (which is why scipy does not do that).
- PyTree support will be a separate PR.
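To make the first option concrete, here is a minimal sketch of the real-vector wrapping (original_fun and dim are placeholders; this reuses the existing real-valued BFGS rather than implementing a true complex L-BFGS):

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def original_fun(z):
    # Toy complex-input, real-output objective with its minimum at z = 1 + 2j.
    return jnp.sum(jnp.abs(z - (1.0 + 2.0j)) ** 2)

dim = 3  # number of complex parameters

def real_wrapper(x):
    # x has length 2*dim: real parts first, imaginary parts second.
    return original_fun(x[:dim] + 1.0j * x[dim:])

x0 = jnp.zeros(2 * dim)
result = minimize(real_wrapper, x0, method="BFGS")
z_opt = result.x[:dim] + 1.0j * result.x[dim:]
```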
Thanks for the paper reference. I think you're right that BFGS needs special handling for complex input. I'll think about the options and suggest something in a small PR. Highest precision in line search is fine. Agreed on leaving the inverse Hessian estimate as None. I looked over your code and it looks suitable for a PR. I need to go over the equations in more detail, but a first indication will be whether it passes tests against scipy's L-BFGS.
If you are looking at the complex formulas, note that I chose the JAX convention for all gradient-type variables (g and y). This means that all g's and y's differ from the Sorber paper by a complex conjugation.
Oh, what a great thread! I was playing with jaxopt.ScipyBoundedMinimize with the "L-BFGS-B" method and ran into problems because this scipy wrapper deals with jnp_to_onp conversion. So reading this thread may be the solution. Where can I find the how-to for this with jax 0.3.10? Thanks.
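For context, what I was playing with looks roughly like this (class and argument names from memory, so please check the jaxopt docs for your version):

```python
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

def objective(x):
    # Toy quadratic; the box constraints below pull the solution to the bound.
    return jnp.sum((x - 2.0) ** 2)

solver = ScipyBoundedMinimize(fun=objective, method="L-BFGS-B")
lower = jnp.zeros(3)
upper = jnp.ones(3)
params, state = solver.run(jnp.full(3, 0.5), bounds=(lower, upper))
```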
Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py
But wasn't sure if anyone else was interested or had something similar already.