google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BFGS/Quasi-Newton optimizers? #1400

Open proteneer opened 4 years ago

proteneer commented 4 years ago

Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py

But wasn't sure if anyone else was interested or had something similar already.

shoyer commented 4 years ago

Yes, I would be really excited about this! In a non-stochastic setting, these are usually much more effective than stochastic gradient descent.

We could even define gradients for these optimizers using implicit root finding (I.e., based on https://github.com/google/jax/pull/1339)

shoyer commented 4 years ago

L-BFGS in particular would be really nice to have: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/lbfgs.py

fehiepsi commented 4 years ago

This would be great! I am interested and happy to review the math of your PR. (Another reference is PyTorch's L-BFGS, which is mostly based on https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html.)

proteneer commented 4 years ago

Whoa - differentiable root finding is pretty whack! The PyTorch L-BFGS seems significantly more advanced than the tfp one - let me read them over to see if we have everything that's needed.

fehiepsi commented 4 years ago

I think the difference between the two versions is that PyTorch uses strong Wolfe line search (not more advanced) while TensorFlow uses Hager-Zhang line search. I would propose sticking with Hager-Zhang line search if you are familiar with the TensorFlow code. :)
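
For reference, the strong Wolfe conditions under discussion are the standard ones: for a step size \alpha along a descent direction p_k, with constants 0 < c_1 < c_2 < 1,

    f(x_k + \alpha p_k) \le f(x_k) + c_1 \alpha \, \nabla f(x_k)^\top p_k    (sufficient decrease)
    |\nabla f(x_k + \alpha p_k)^\top p_k| \le c_2 \, |\nabla f(x_k)^\top p_k|    (curvature)

The Hager-Zhang line search additionally uses an approximate version of these conditions that remains satisfiable in floating point very close to a minimum.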

shoyer commented 4 years ago

It looks like SciPy also uses strong Wolfe line search. I'm sure this makes a small difference, but it's all at the margin -- I'd be happy to have either in JAX.

fehiepsi commented 4 years ago

When I used the strong Wolfe line search to do Bayesian optimization for some toy datasets, I faced instability issues (the Hessian blows up -> NaNs appear) near the optimum value. So I hope that Hager-Zhang line search will do better. Looking at the documentation of Hager-Zhang, it seems to deal with the issue I faced: "On a finite precision machine, the exact Wolfe conditions can be difficult to satisfy when one is very close to the minimum and as argued by [Hager and Zhang (2005)][1], one can only expect the minimum to be determined within square root of machine precision".

TuanNguyen27 commented 4 years ago

@proteneer Just wondering if you are still working on this! I'm also interested in attempting an implementation.

proteneer commented 4 years ago

Hi @TuanNguyen27 - sorry this fell off the radar. I ended up implementing the FIRE optimizer instead based on @sschoenholz 's code. Feel free to take a stab yourself, as I'm not working on this right now.

TuanNguyen27 commented 4 years ago

@shoyer @fehiepsi I've been staring at PyTorch's implementation, but it contains too many parameters (both initial and auxiliary) to pack into the state variable if I try to follow the init_fun, update_fun, get_params template defined in optimizers.py. Is that something I need to strictly follow, or do you have any design advice / suggestions that could make this more approachable? :D

mattjj commented 4 years ago

optimizers.py is specialized to stochastic first-order optimizers; for second-order optimizers, don't feel constrained to follow its APIs. IMO just do what makes sense to you (while hopefully following the spirit of being functional and minimal).

fehiepsi commented 4 years ago

@TuanNguyen27 PyTorch's L-BFGS is implemented to have an interface similar to the other stochastic optimizers. In practice, I only use one "step" of PyTorch L-BFGS to find the solution, so I guess you don't need to follow the other JAX optimizers, and there is no need to manage a state variable except for internal while loops. In my view, the initial API in JAX could be close to SciPy's L-BFGS-B, that is:

def lbfgs(fun, x0, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return solution

or similar to build_ode

def lbfgs(fun, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return optimize_fun

solution = lbfgs(f)(x0)

But as Matt said, just do what makes sense to you. FYI, Wikipedia has a nice list of test functions to verify your implementation.

shoyer commented 4 years ago

For JAX, I think it makes sense to target an interface like SciPy here.
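
For concreteness, the SciPy interface being referenced is scipy.optimize.minimize; a minimal example of its shape (standard SciPy usage, shown only to illustrate the target API):

from scipy.optimize import minimize

def rosenbrock(x):
    # classic 2-D test function with minimum at (1, 1)
    return (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

result = minimize(rosenbrock, x0=[-1.2, 1.0], method="L-BFGS-B",
                  options={"gtol": 1e-6, "maxiter": 500})
print(result.x, result.fun, result.nit)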

Jakob-Unfried commented 4 years ago

Any news on this?

I wrote an adaptation of PyTorch's L-BFGS implementation for my personal use (even supporting complex arguments) that works well for me. I am happy to share it, but I am new to contributing to large projects and would need some pointers on where to start, how to write tests, etc. (@shoyer ?)

joglekara commented 4 years ago

I'm also more than happy to help here @shoyer and @Jakob-Unfried

shoyer commented 4 years ago

It would be great for someone to start on this! Ideally it would be broken up into a number of smaller PRs, e.g., starting with line search, then adding L-BFGS, then making it differentiable via the implicit function theorem.

I think the TF-probability implementation is probably the best place to start for a JAX port. The TF control flow ops are all functional and have close equivalents in JAX. The pytorch version looks fine but it’s very object oriented/stateful, which could make it harder to integrate into JAX.

Jakob-Unfried commented 4 years ago

I have an implementation of L-BFGS with strong Wolfe line search lying around. It just needs a few touch-ups to make it nicer, as well as documentation that people who are not me will understand.

I will get it ready for someone to review, probably late this evening (European time).

A couple of questions:

  1. There are quite a lot of parameters (11 if I didn't miscount) in the algorithm, and in most cases people will just use the defaults. My idea would be to make them **kwargs. Thoughts?

  2. I wanted to enable the cost_function to take an arbitrary pytree as input. Since these are a finite number of independent variables, I think of them as a single column vector for the maths of the optimisation. During the algorithm, I need to perform vector operations (addition, scalar product, etc.) on these "vectors". I implemented a lot of helper functions, e.g. _vector_add(x1, x2), which essentially just use the tree utils to perform the necessary steps - in this example, return tree_multimap(lambda arr1, arr2: arr1 + arr2, x1, x2). I use a total of 12 such functions, so it might unnecessarily clutter the code (see the sketch after this list). The alternative I see is to implement the algorithm for cost_functions which can only take a single 1D array as input and use a decorator to enable arbitrary pytrees. Thoughts?

  3. Where do I write tests? Is there a conventional format to follow? Maybe a good example to look at?

  4. Do you prefer the semantics x_optim = lbfgs(fun, x0) or x_optim = lbfgs(fun)(x0)? The latter would probably make more sense if we want to think of it as implicitly defining a new function, but in the end it can just be a thin wrapper around the former.
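
A minimal sketch of the kind of pytree helpers described in question 2 (the helper names here are hypothetical; the approach just uses jax.tree_util):

import jax.numpy as jnp
from jax.tree_util import tree_multimap, tree_map, tree_leaves

# Treat an arbitrary pytree of arrays as a single flat "vector".
def _vector_add(x1, x2):
    return tree_multimap(lambda a, b: a + b, x1, x2)

def _vector_scale(alpha, x):
    return tree_map(lambda a: alpha * a, x)

def _vector_dot(x1, x2):
    return sum(jnp.vdot(a, b) for a, b in zip(tree_leaves(x1), tree_leaves(x2)))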

shoyer commented 4 years ago

  1. To the extent it makes sense, I would suggest sticking to the general interface of scipy.optimize.minimize. This suggests stuffing everything into options. We can drop arguments that don't really make sense for JAX; e.g., we could use jax.grad or jax.value_and_grad internally rather than requiring that the gradient be passed explicitly with jac=True.

  2. Handling pytrees sounds great! Helper functions are fine, though if you can group a series of operations and only apply tree_multimap once, that's even better. Note that you can use it as a decorator with curry:

    
    In [6]: from jax.util import curry

    In [7]: from jax.tree_util import tree_multimap

    In [8]: @curry(tree_multimap)
       ...: def f(x, y):
       ...:     return x + y
       ...:

    In [9]: f({'a': 1}, {'a': 2})
    Out[9]: {'a': 3}



3. Take a look at the existing tests in `tests/`. We run tests with pytest, with classes inheriting from `JaxTestCase`, which handles parameterization. `vectorize_test.py` might be a good minimal example to look at.

4. I would lean towards the former, which matches SciPy.

Jakob-Unfried commented 4 years ago

@joglekara @shoyer

Let me know what you think about my implementation: https://github.com/Jakob-Unfried/jax/commit/b28300cbcc234cda0b40b16cf15678e3f78e4085

Open Questions:

shoyer commented 4 years ago

@Jakob-Unfried it looks like a great start, but ideally we would use JAX's functional control flow so it is compatible with jit. There is lots of scalar arithmetic and book-keeping that XLA could likely accelerate nicely. Note: the implementation doesn't need to be differentiable, since we can define the gradient pass as a separate root finding problem via lax.custom_root.
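
A minimal sketch of that idea, with a toy gradient-descent solver standing in for BFGS (the names here are hypothetical; the point is that gradients come from the stationarity condition grad f(x*) = 0 rather than from unrolling the solver):

import jax
import jax.numpy as jnp
from jax import lax

def argmin_via_custom_root(f, x0):
    optimality = jax.grad(f)  # a minimiser is a root of grad f

    def solve(opt_fn, x):
        # stand-in for BFGS/L-BFGS: a fixed number of gradient steps
        for _ in range(200):
            x = x - 0.1 * opt_fn(x)
        return x

    def tangent_solve(g, y):
        # g is the (linear) JVP of `optimality` at the solution; solve g(v) = y
        return jnp.linalg.solve(jax.jacfwd(g)(y), y)

    return lax.custom_root(optimality, x0, solve, tangent_solve)

# d x*(theta)/d theta for x*(theta) = argmin_x ||x - theta||^2 is the identity:
x_star = lambda theta: argmin_via_custom_root(
    lambda x: jnp.sum((x - theta) ** 2), jnp.zeros_like(theta))
print(jax.jacfwd(x_star)(jnp.array([1.0, 2.0])))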

Jakob-Unfried commented 4 years ago

Unfortunately, I am not (yet) very familiar with the requirements of jit and how to write code that it can process efficiently.

I see two paths forward:

david-waterworth commented 4 years ago

I'd also love to see an implementation of Levenberg-Marquardt. The Matlab ML toolkit includes an implementation, and it's regularly used to train time-series (regression) RNN (NARX) models - it seems to perform much better than gradient descent. I'd love to be able to validate this and compare it with other second-order solvers.

I'll watch the progress of this thread with interest, try to contribute where I can, and if/when I feel I understand enough I might give it a go.

shoyer commented 4 years ago

@Jakob-Unfried Let me suggest an iterative process:

  1. Check in a working version of your current code (with tests) without changing the control flow from Python.
  2. Convert functions one by one into a form suitable for use with jit, starting with the inner-most layers and working outwards. We can verify that this works by adding explicit @jit annotations.
  3. Make the whole thing differentiable, via the implicit function theorem.

This can and should happen over multiple PRs. As long as you or others are willing to help, I am happy to start on the process.

As for "how to write code that works well with jit", the general guidance is:

  1. Convert loops/recursion into lax.while_loop. Yes, this typically means you need to pack a lot of extra state into the val.
  2. Convert if statements into np.where (usually it isn't worth using lax.cond); a minimal sketch of both patterns follows below.
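
For example (the loop body here is hypothetical, just to illustrate the shape of the rewrite):

import jax.numpy as jnp
from jax import lax

def run(x0, max_iter=100, tol=1e-6):
    def cond_fun(val):
        i, x = val
        return (jnp.abs(x) > tol) & (i < max_iter)

    def body_fun(val):
        i, x = val
        x = 0.5 * x                  # some update step
        x = jnp.where(x < 0, -x, x)  # an `if` expressed as jnp.where, not lax.cond
        return i + 1, x

    # A Python `while` becomes lax.while_loop; all loop state lives in `val`.
    _, x = lax.while_loop(cond_fun, body_fun, (0, x0))
    return x

print(run(jnp.array(3.0)))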

TuanNguyen27 commented 4 years ago

Also happy to help !

Jakob-Unfried commented 4 years ago

@shoyer Sounds good! I will write (and run) a full set of tests and PR the version with pure-python control-flow later today.

I will then get back here with a TODO list.

@TuanNguyen27 Let's chat on Gitter?

TuanNguyen27 commented 4 years ago

@Jakob-Unfried sounds good, I just sent you a gitter DM to follow up

HDembinski commented 4 years ago

What is the benefit of adding a minimizer? The minimizer calls the cost function and its gradient. Both can already be JAX-accelerated, and those functions dominate the computing time. I don't see how one could gain here. The power of JAX is that you can mix it with other libraries.

shoyer commented 4 years ago

You absolutely can and should leverage optimizers from SciPy for many applications. I'm interested in optimizers written in JAX for use cases that involve nested optimization, like meta-optimization.

HDembinski commented 4 years ago

OK, I suppose that makes sense, never mind then.

phinate commented 4 years ago

@Jakob-Unfried it looks like a great start, but ideally we would use JAX's functional control flow so it is compatible with jit. There is lots of scalar arithmetic and book-keeping that XLA could likely accelerate nicely. Note: the implementation doesn't need to be differentiable, since we can define the gradient pass as a separate root finding problem via lax.custom_root.

Hi @shoyer, would you be able to elaborate on how you would utilize lax.custom_root in practice to perform the gradient pass of an existing SciPy optimizer, e.g. scipy.minimize? Is that possible?

(I tried this with an example from the SciPy docs following the jax docs, but I wasn't able to get output due to the incompatibility of tracing gradients through NumPy functions.)

shoyer commented 4 years ago

It should be possible to wrap a SciPy optimizer to make it differentiable, but I haven't figured out exactly how yet. I think it would require making a new JAX primitive to give the SciPy function an abstract eval rule first.

Joshuaalbert commented 4 years ago

Hi @shoyer and others. I made a PR for a jittable implementation of BFGS. See #3101.

nrontsis commented 4 years ago

I am confused about the benefits of implementing BFGS or any other non-trivial optimisation solvers in JAX.

If it's "autodiff" that we want, then this is addressed by so-called "sensitivity analysis" in optimisation. A software tool that follows this approach is sIPOPT (see this paper for more details). Note that tools like IPOPT and sIPOPT allow for the solution (and gradients) of generic, constrained problems of the form:

minimize                f(x)
subject to       l_i <= g_i(x) <= u_i

In contrast, the BFGS algorithm implemented in #3101 only supports unconstrained problems. This is a step back compared to SciPy's L-BFGS-B, which is memory efficient and allows for bound constraints on the variables.

Moreover, I would guess that automatic differentiation on pure JAX solvers like the one in #3101 might be considerably slower than the approach of sIPOPT.

I apologise for the non-constructive comment - I believe that JAX can be an amazing tool for optimization, but rewriting elaborate solvers in pure JAX does not seem useful to me - except for research purposes. In my opinion,

sound like the features most needed by the optimisation community.

shoyer commented 4 years ago

@nrontsis Support for constrained optimization problems would certainly be nice to have as well, but there are plenty of use cases for unconstrained optimization, too. One good example that is relevant to JAX's users is optimizing the weights of a neural network. The number of 👍 on this issue and the fact that L-BFGS exists in PyTorch are both good indications that it would be a welcome feature.

The BFGS PR from #3101 is only a first step. We plan to extend it to support both the limited-memory version (essential for large problems) and implicit differentiation, i.e., using the implicit function theorem (I assume this is what sIPOPT does).
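
For reference, this is the standard implicit function theorem argument: if x^*(\theta) satisfies \nabla_x f(x^*, \theta) = 0 and the Hessian is invertible, then

    \frac{\partial x^*}{\partial \theta} = -\left[\nabla^2_{xx} f(x^*, \theta)\right]^{-1} \nabla^2_{x\theta} f(x^*, \theta)

so the derivative of the solution requires only a linear solve at the optimum, with no need to differentiate through the solver's iterations.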

If you're interested in collaborating on better optimization tools in JAX (including integration with external tools) we'd love to work with you. I would suggest opening new issues for specific suggestions. Quasi-Newton methods (this issue) is only one piece of the puzzle.

nrontsis commented 4 years ago

Thanks for the reply @shoyer. I have only started using JAX recently and I would love to contribute in the future wherever I can.

The number of 👍 on this issue and the fact that L-BFGS exists in PyTorch are both good indications that it would be a welcome feature.

I felt equally puzzled when I saw that PyTorch and TensorFlow have L-BFGS implemented (and a -B variant was attempted, unfortunately without getting merged, at least for now). What are the advantages of a native implementation over a thin wrapper around SciPy's L-BFGS-B? You mention meta-optimization as a use case in an earlier comment - I would be interested to hear what advantages you see in this case (especially if the differentiation is done implicitly).

Joshuaalbert commented 4 years ago

@nrontsis One advantage is speed when you want to embed an optimisation inside of a larger algorithm (in line with one of JAX's aims, high performance). In one of my use cases, I have an algorithm where one component of it requires doing unconstrained minimisation. When I use the embedded JAX written BFGS and write the whole algorithm in JAX and then compile it, it runs orders of magnitude faster than breaking out of XLA to use scipy with a wrapper. It's the difference between waiting days versus minutes in my case.

shoyer commented 4 years ago

The main reason to rewrite algorithms in JAX is to support JAX's function transformations. In this case, that would mostly be relevant for performance:

nrontsis commented 4 years ago

I see, thank you for the explanation!

brianwa84 commented 3 years ago

TFP-on-JAX now supports L-BFGS:

!pip install -q tfp-nightly[jax]
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp

def rosenbrock(coord):
  """The Rosenbrock function in two dimensions with a=1, b=100.

  Args:
    coord: Array with shape [2]. The coordinate of the point to evaluate
      the function at.

  Returns:
    fv: A scalar tensor containing the value of the Rosenbrock function at
      the supplied point.
    dcoord: Array with shape [2]. The derivative of the function with respect to 
      `coord`.
  """
  x, y = coord[0], coord[1]
  fv = (1 - x)**2 + 100 * (y - x**2)**2
  dfx = 2 * (x - 1) + 400 * x * (x**2 - y)
  dfy = 200 * (y - x**2)
  return fv, jnp.stack([dfx, dfy])

start = jnp.float32([-1.2, 1.0])
results = tfp.optimizer.lbfgs_minimize(
    rosenbrock, initial_position=start, tolerance=1e-5)

results.position  # => DeviceArray([1.0000001, 1.0000002], dtype=float32)

brianwa84 commented 3 years ago

Also, in tomorrow's nightly build, BFGS and Nelder-Mead optimizers will be supported in TFP JAX.

shoyer commented 3 years ago

Now that BFGS support has been merged and provides some general scaffolding (see #3101), we'd love to get L-BFGS (and other optimizers) in JAX proper, as well as support for pytrees and implicit differentiation.
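
A minimal usage sketch of the merged BFGS entry point, assuming the jax.scipy.optimize.minimize wrapper that the BFGS work exposes:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

def rosenbrock(x):
    return (1.0 - x[0]) ** 2 + 100.0 * (x[1] - x[0] ** 2) ** 2

result = minimize(rosenbrock, jnp.array([-1.2, 1.0]), method="BFGS")
print(result.x, result.success)  # x should be close to [1., 1.]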

Joshuaalbert commented 3 years ago

Thanks @shoyer for your hard work getting this in. I'll be happy to add some more. I'm interested in getting (L-)BFGS in with bound constraints. I can also add pytree support to BFGS, and, as promised in my code comment in the line search, I still intend to profile the line search between the versions using where's and cond's.

KeunwooPark commented 3 years ago

Hi. I'm interested in mixing an Adam optimizer and an L-BFGS optimizer. JAX provides an Adam optimizer, so I used that. But I don't understand how I can turn the network parameters from JAX's Adam optimizer into the input of tfp.optimizer.lbfgs_minimize().

The code below conceptually shows what I want to do: it tries to optimize a network with Adam first, and then uses L-BFGS.

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)

    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))

    def func(params):
        return loss(params, _x, _y)

    net_params = get_params(opt_state)
    results = tfp.optimizer.lbfgs_minimize(
        func, initial_position=net_params, tolerance=1e-5)

if __name__ == "__main__":
    main()

Any kind of advice would be very helpful to me. @brianwa84 Could you provide an example that mixes the two kinds of optimizers?

Joshuaalbert commented 3 years ago

Re: @nrontsis https://github.com/google/jax/issues/1400#issuecomment-648453239.

I am confused about the benefits of implementing BFGS or any other non-trivial optimisation solvers in JAX.

I made a comparison of BFGS against scipy+numpy, scipy+jax, and pure JAX on a benchmark problem, N-d least squares + L1 regularisation. [Screenshot from 2020-10-05: benchmark timing comparison.]

@KeunwooPark I can't speak to using TFP in combination, but maybe https://github.com/google/jax/issues/3847 is interesting to you. If you want to use a pure JAX L-BFGS then you'll need to wait; I plan on implementing it mid-November, along with pytree arguments.

nrontsis commented 3 years ago

That's quite impressive! Can you share the script used to generate it?

Joshuaalbert commented 3 years ago

Just put it up here: https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae

EDIT: the scipy+numpy and scipy+jitted(func) runs use numerical Jacobians, which require many more function evaluations. Best to compare pure JAX and scipy+jitted(func and grad).

KeunwooPark commented 3 years ago

@Joshuaalbert Thank you for your comment! I guess I have to think more about a workaround or just use TensorFlow. Or use your L-BFGS implementation. According to your graph, mixing scipy and jax doesn't seem to be a good idea. I'm interested in implementing L-BFGS, but I'm really new to these concepts and still learning.

Joshuaalbert commented 3 years ago

@KeunwooPark If you're looking for something immediately, scipy+jax will probably do the job, unless speed is crucial.

sharadmv commented 3 years ago

@KeunwooPark I modified your example to feed the neural network weights into L-BFGS. Unfortunately, tfp.optimizer.lbfgs_minimize does not support optimizing over structures of tensors/arrays, but you can concatenate the network weights together, and then split them back up in the loss function.

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)

    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))

    net_params = get_params(opt_state)

    def concat_params(params):
        flat_params, params_tree = jax.tree_util.tree_flatten(params)
        params_shape = [x.shape for x in flat_params]
        return jnp.concatenate([x.reshape(-1) for x in flat_params]), (params_tree, params_shape)

    param_vector, (params_tree, params_shape) = concat_params(net_params)

    @jax.value_and_grad
    def func(param_vector):
        split_params = jnp.split(param_vector,
                np.cumsum([np.prod(s) for s in params_shape[:-1]]))
        flat_params = [x.reshape(s) for x, s in zip(split_params, params_shape)]
        params = jax.tree_util.tree_unflatten(params_tree, flat_params)
        return loss(params, _x, _y)

    results = tfp.optimizer.lbfgs_minimize(
        jax.jit(func), initial_position=param_vector, tolerance=1e-5)

if __name__ == "__main__":
    main()

Hope this is helpful! In the long run, better support for these structured inputs to L-BFGS in TFP will likely be the path forward.
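
As a side note, the flatten/unflatten boilerplate above can likely be replaced with jax.flatten_util.ravel_pytree (assuming it is available in your JAX version), e.g. continuing from the variables defined in the example above:

from jax.flatten_util import ravel_pytree

param_vector, unravel = ravel_pytree(net_params)

@jax.value_and_grad
def func(param_vector):
    return loss(unravel(param_vector), _x, _y)

results = tfp.optimizer.lbfgs_minimize(
    jax.jit(func), initial_position=param_vector, tolerance=1e-5)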

KeunwooPark commented 3 years ago

@sharadmv Wow. Thank you very much! This helps me a lot.