proteneer opened this issue 5 years ago
Yes, I would be really excited about this! In a non-stochastic setting, these are usually much more effective than stochastic gradient descent.
We could even define gradients for these optimizers using implicit root finding (i.e., based on https://github.com/google/jax/pull/1339).
L-BFGS in particular would be really nice to have: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/lbfgs.py
This would be great! I am interested in and happy to review the math of your PR. (Another reference is PyTorch's L-BFGS, which is mostly based on https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html.)
Whoa - differentiable root finding is pretty whack! The PyTorch L-BFGS seems significantly more advanced than the tfp one - let me read them over to see if we have everything that's needed.
I think the difference between the two versions is: PyTorch uses a strong Wolfe line search (not more advanced), while TensorFlow uses a Hager-Zhang line search. I would propose sticking with the Hager-Zhang line search if you are familiar with the TensorFlow code. :)
It looks like scipy also uses a strong Wolfe line search. I'm sure this makes a small difference, but it's all at the margin -- I'd be happy to have either in JAX.
When I used the strong Wolfe line search to do Bayesian optimization on some toy datasets, I faced instability issues (the Hessian blows up and NaNs appear) near the optimum, so I hope that the Hager-Zhang line search will do better. Looking at the Hager-Zhang documentation, it seems to deal with the issue I faced: "On a finite precision machine, the exact Wolfe conditions can be difficult to satisfy when one is very close to the minimum and as argued by [Hager and Zhang (2005)][1], one can only expect the minimum to be determined within square root of machine precision".
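For readers following along, the strong Wolfe conditions under discussion are easy to state; a minimal illustrative check (not any particular library's implementation; the usual defaults c1=1e-4, c2=0.9 are assumed) might look like:

```python
import jax.numpy as jnp

def satisfies_strong_wolfe(f, grad_f, x, p, alpha, c1=1e-4, c2=0.9):
    """Check the strong Wolfe conditions for step size `alpha` along direction `p`."""
    g0_dot_p = jnp.vdot(grad_f(x), p)  # directional derivative at x
    x_new = x + alpha * p
    sufficient_decrease = f(x_new) <= f(x) + c1 * alpha * g0_dot_p
    curvature = jnp.abs(jnp.vdot(grad_f(x_new), p)) <= c2 * jnp.abs(g0_dot_p)
    return sufficient_decrease & curvature
```

As noted above, very close to the minimum the sufficient-decrease test becomes dominated by floating-point cancellation, which is the failure mode the Hager-Zhang approximate conditions are designed to avoid.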
@proteneer Just wondering if you are still working on this! I'm also interested in attempting an implementation.
Hi @TuanNguyen27 - sorry this fell off the radar. I ended up implementing the FIRE optimizer instead based on @sschoenholz 's code. Feel free to take a stab yourself, as I'm not working on this right now.
@shoyer @fehiepsi I've been staring at PyTorch's implementation, but it contains too many parameters (both initial and auxiliary) to pack into the `state` variable if I'm trying to follow the `init_fun`, `update_fun`, `get_params` template defined in `optimizers.py`. Is it something I need to strictly follow, or do you have any design advice / suggestion that could make this more approachable? :D
optimizers.py is specialized to stochastic first-order optimizers; for second-order optimizers, don't feel constrained to follow its APIs. IMO just do what makes sense to you (while hopefully following the spirit of being functional and minimal).
@TuanNguyen27 PyTorch LBFGS is implemented to have a similar interface as other stochastic optimizers. In practice, I only use 1 "step" of PyTorch LBFGS to find the solution, so I guess you don't need to follow other JAX optimizers, and there is no need to manage a `state` variable except for internal while loops. In my view, the initial API in JAX would be close to scipy lbfgs-b (https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb), that is:

def lbfgs(fun, x0, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return solution

or similar to build_ode (https://github.com/google/jax/blob/master/jax/experimental/ode.py#L377):

def lbfgs(fun, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return optimize_fun

solution = lbfgs(f)(x0)

But as Matt said, just do what makes sense to you. FYI, Wikipedia has a nice list of test functions to verify your implementation: https://en.wikipedia.org/wiki/Test_functions_for_optimization.
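As a concrete example of such a verification, the 2-D Rosenbrock function is easy to write down; the `lbfgs` call below is hypothetical and simply follows the first proposed signature above:

```python
import jax.numpy as jnp

def rosenbrock(x):
    # Standard 2-D Rosenbrock with a=1, b=100; the global minimum is at (1, 1).
    return (1.0 - x[0]) ** 2 + 100.0 * (x[1] - x[0] ** 2) ** 2

x0 = jnp.array([-1.2, 1.0])
# Hypothetical call, following the proposed SciPy-like signature:
# solution = lbfgs(rosenbrock, x0, gtol=1e-6, maxiter=500)
# assert jnp.allclose(solution, jnp.ones(2), atol=1e-4)
```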
For JAX, I think it makes sense to target an interface like SciPy here.
Any news on this?
I wrote an adaptation of PyTorch's L-BFGS implementation for my personal use (it even supports complex arguments) that works well for me. I am happy to share it, but I am new to contributing to large projects and would need some pointers on where to start, how to write tests, etc. (@shoyer?)
I'm also more than happy to help here @shoyer and @Jakob-Unfried
It would be great for someone to start on this! Ideally it would be broken up into a number of smaller PRs, e.g., starting with line search, then adding L-BFGS, then making it differentiable via the implicit function theorem.
I think the TF-probability implementation is probably the best place to start for a JAX port. The TF control flow ops are all functional and have close equivalents in JAX. The pytorch version looks fine but it’s very object oriented/stateful, which could make it harder to integrate into JAX.
I have an implementation of L-BFGS with Strong Wolfe line search lying around. Just needs a few touch ups to make it nicer as well as documentation that people who are not me will understand.
I will get it ready for someone to review, probably late this evening (european time)
A couple of questions:
1. There are quite a lot of parameters (11, if I didn't miscount) in the algorithm, and in most cases people will just use the defaults. My idea would be to make them `**kwargs`. Thoughts?
2. I wanted to enable the `cost_function` to take an arbitrary pytree as input. Since these are a finite number of independent variables, I think of them as a single column vector for the maths of the optimisation. During the algorithm, I need to perform vector operations (addition, scalar product, etc.) on these "vectors". I implemented a lot of helper functions, e.g. `_vector_add(x1, x2)`, which essentially just use tree utils to perform the necessary steps; in this example, `return tree_multimap(lambda arr1, arr2: arr1 + arr2, x1, x2)`. I use a total of 12 such functions, so it might unnecessarily clutter the code. The alternative I see is to implement the algorithm for cost functions which can only take a single 1D array as input and use a decorator to enable arbitrary pytrees. Thoughts?
3. Where do I write tests? Is there a conventional format to follow? Maybe a good example to look at?
4. Do you prefer the semantics `x_optim = lbfgs(fun, x0)` or `x_optim = lbfgs(fun)(x0)`? The latter would probably make more sense if we want to think about it as implicitly defining a new function, but in the end, the latter can just be a thin wrapper around the former.
1. To the extent it makes sense, I would suggest sticking to the general interface of `scipy.optimize.minimize`. This suggests stuffing everything into `options`. We can drop arguments that don't really make sense for JAX, e.g., we could use `jax.grad` or `jax.value_and_grad` internally rather than requiring that the gradient be passed explicitly with `jac=True`.
2. Handling pytrees sounds great! Helper functions are fine, though if you can group a series of operations and only apply `tree_multimap` once, that's even better. Note that you can use it as a decorator with `curry`:

In [6]: from jax.util import curry

In [7]: from jax.tree_util import tree_multimap

In [8]: @curry(tree_multimap)
   ...: def f(x, y):
   ...:     return x + y
   ...:

In [9]: f({'a': 1}, {'a': 2})
Out[9]: {'a': 3}
3. Take a look at the existing tests in `tests/`. We run tests with pytest, with classes inherited from `JaxTestCase`, which handles parameterization. `vectorize_test.py` might be a good minimal example to look at.
4. I would lean towards the former, which matches SciPy.
@joglekara @shoyer
Let me know what you think about my implementation: https://github.com/Jakob-Unfried/jax/commit/b28300cbcc234cda0b40b16cf15678e3f78e4085
Open Questions:
@Jakob-Unfried it looks like a great start, but ideally we would use JAX's functional control flow so it is compatible with `jit`. There is lots of scalar arithmetic and book-keeping that XLA could likely accelerate nicely. Note: the implementation doesn't need to be differentiable, since we can define the gradient pass as a separate root finding problem via `lax.custom_root`.
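As a rough illustration of that idea (this is only a sketch of the `lax.custom_root` mechanism, not the eventual design; `inner_solver` is a hypothetical black-box minimizer, and the tangent solve below uses a dense linear solve for simplicity):

```python
import jax
import jax.numpy as jnp
from jax import lax

def differentiable_argmin(objective, x0, inner_solver):
    """Differentiate through argmin(objective) via the condition grad(objective)(x*) = 0."""
    grad_fn = jax.grad(objective)

    def solve(_, x_init):
        # Forward pass: any solver that returns an (approximate) stationary point;
        # it does not need to be differentiable itself.
        return inner_solver(grad_fn, x_init)

    def tangent_solve(g, y):
        # Backward pass: g is the linearization of grad_fn at the solution
        # (a Hessian-vector product); solve the linear system g(x) = y.
        return jnp.linalg.solve(jax.jacobian(g)(jnp.zeros_like(y)), y)

    return lax.custom_root(grad_fn, x0, solve, tangent_solve)
```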
Unfortunately, I am not (yet) very familiar with the requirements of `jit` and how to write code that it can process efficiently.
I see two paths forward:
1. Someone gives me pointers on what needs to be done (e.g. "Instead of `for` loops I should use `jax.lax.fori_loop`", etc.). With these, two main points come up which I am not sure about:
   - Loops that can terminate early: do I just use `while_loop` and have `cond_fun` do the checks that could break the loop?
   - The history variables (`y_history`, `s_history` and `rho_history`): would I just pack all of these into `val` (the (2nd) argument of `body_fun`)?
2. Someone takes my implementation and improves it. You don't necessarily need to understand the details of the implementation.
I'd also love to see an implementation of Levenberg-Marquardt. The MATLAB ML toolkit includes an implementation, and it's regularly used to train time-series (regression) RNN (NARX) models - it seems to perform much better than gradient descent. I'd love to be able to validate this and compare with other 2nd-order solvers.
I'll watch the progress of this thread with interest and try to contribute where I can, and if/when I feel I understand enough I might give it a go.
@Jakob-Unfried Let me suggest an iterative process: start from your current version (with pure-Python control flow and a set of tests), then incrementally rewrite it to support `jit`, starting with the inner-most layers and working outwards. We can verify that this works by adding explicit `@jit` annotations. This can and should happen over multiple PRs. As long as you or others express willingness to help, I am happy to start on the process.
As for "how to write code that works well with jit
", the general guidance is:
lax.while_loop
. Yes, this typically means you need to pack a lot of extra state into the val
.if
statements into np.where
(usually it isn't worth using lax.cond
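To make the first point concrete, here is a minimal sketch of the pattern on a toy gradient-descent loop (illustrative names only, not L-BFGS itself): all break conditions move into `cond_fun`, and any state the loop touches is packed into `val`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def toy_descent(f, x0, lr=0.1, gtol=1e-6, maxiter=1000):
    """Toy first-order loop illustrating the lax.while_loop state-packing pattern."""
    grad_f = jax.grad(f)

    def cond_fun(val):
        # Every condition that would have been a `break` lives here.
        x, g, k = val
        return (jnp.linalg.norm(g) > gtol) & (k < maxiter)

    def body_fun(val):
        x, g, k = val
        x_new = x - lr * g
        return x_new, grad_f(x_new), k + 1

    x, g, k = lax.while_loop(cond_fun, body_fun, (x0, grad_f(x0), 0))
    return x
```

The fixed-size history buffers (`s_history`, `y_history`, `rho_history`) would similarly live inside `val` as arrays and be updated in place, e.g. with `lax.dynamic_update_slice`.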
Also happy to help!
@shoyer Sounds good! I will write (and run) a full set of tests and PR the version with pure-python control-flow later today.
I will then get back here with a TODO list.
@TuanNguyen27 Let's chat on Gitter?
@Jakob-Unfried sounds good, I just sent you a gitter DM to follow up
What is the benefit of adding a minimizer? The minimizer calls the cost function and its gradient. Both can be JAX accelerated already and the functions dominate the computing time. I don't see how one could gain here. The power of JAX is that you can mix with other libraries.
You absolutely can and should leverage optimizers from SciPy for many applications. I'm interested in optimizers written in JAX for use cases that involve nested optimization, like meta-optimization.
Ok, suppose that makes sense, nevermind then.
Hi @shoyer, would you be able to elaborate on how you would utilize `lax.custom_root` in practice to perform the gradient pass of an existing SciPy optimizer, e.g. `scipy.minimize`? Is that possible?
(I tried this with an example from the SciPy docs (https://scipy.github.io/devdocs/generated/scipy.optimize.minimize.html#scipy.optimize.minimize), following the JAX docs (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.custom_root.html#jax.lax.custom_root), but I wasn't able to get output due to the incompatibility of tracing gradients through `numpy` functions.)
It should be possible to wrap a SciPy optimizer to make it differentiable, but I haven't figured out exactly how yet. I think it would require making a new JAX primitive to give the SciPy function an abstract eval rule first.
Hi @shoyer and others. I made a PR for a jittable implementation of BFGS; see #3101.
I am confused about the benefits of implementing BFGS or any other non-trivial optimisation solvers in JAX.
If it's "autodiff" that we want, then this is addressed by the so-called "sensitivity analysis" in optimisation. A software tool that follows this approach is sIPOPT (see this paper for more details). Note that tools like IPOPT and sIPOPT allow for the solution (and gradients) of generic, constrained problems of the form:
minimize f(x)
subject to l_i <= g_i(x) <= u_i
In contrast, the BFGS algorithm implemented in #3101 only supports unconstrained problems. This is a step back compared to L-BFGS-B of SciPy, which is memory efficient and allows for bound constraints on the variables.
Moreover, I would guess that automatic differentiation on pure JAX solvers like the one in #3101 might be considerably slower than the approach of sIPOPT.
I apologise for the non-constructive comment - I believe that JAX can be an amazing tool for optimization, but rewriting elaborate solvers in pure JAX does not seem useful to me - except for research purposes. In my opinion, features along these lines (sensitivity analysis for autodiff, and interfaces to mature existing solvers)
sound like the features most needed by the optimisation community.
@nrontsis Support for constrained optimization problems would certainly be nice to have as well, but there are plenty of use cases for unconstrained optimization, too. One good example that is relevant to JAX's users is optimizing the weights of a neural network. The number of 👍 on this issue and the fact that L-BFGS exists in PyTorch are both good indications that it would be a welcome feature.
The BFGS PR from #3101 is only a first step. We plan to extend it to support both the limited-memory version (essential for large problems) and implicit differentiation, i.e., using the implicit function theorem (I assume this is what sIPOPT does).
If you're interested in collaborating on better optimization tools in JAX (including integration with external tools) we'd love to work with you. I would suggest opening new issues for specific suggestions. Quasi-Newton methods (this issue) is only one piece of the puzzle.
Thanks for the reply @shoyer. I have only started using JAX recently and I would love to contribute in the future wherever I can.
The number of 👍 on this issue and the fact that L-BFGS exists in PyTorch are both good indications that it would be a welcome feature.
I felt equally puzzled when I saw that PyTorch and TensorFlow have L-BFGS implemented (and an L-BFGS-B variant was attempted, unfortunately without getting merged, at least for now). What are the advantages of a native implementation over a thin wrapper of SciPy's L-BFGS-B? You mention meta-optimization as a use case in an earlier comment - I would be interested to hear what advantages you see in this case (especially if the differentiation is done implicitly).
@nrontsis One advantage is speed when you want to embed an optimisation inside of a larger algorithm (in line with one of JAX's aims, high performance). In one of my use cases, I have an algorithm where one component of it requires doing unconstrained minimisation. When I use the embedded JAX written BFGS and write the whole algorithm in JAX and then compile it, it runs orders of magnitude faster than breaking out of XLA to use scipy with a wrapper. It's the difference between waiting days versus minutes in my case.
The main reason to rewrite algorithms in JAX is to support JAX's function transformations. In this case, that would mostly be relevant for performance:
- Using `jit` to compile over the full optimization process.
- Using `vmap` on an optimizer to allow for vectorization.
- `tree_vectorize` from https://github.com/google/jax/pull/3263.

I see, thank you for the explanation!
TFP-on-JAX now supports L-BFGS:
!pip install -q tfp-nightly[jax]

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp

def rosenbrock(coord):
  """The Rosenbrock function in two dimensions with a=1, b=100.

  Args:
    coord: Array with shape [2]. The coordinate of the point to evaluate
      the function at.

  Returns:
    fv: A scalar tensor containing the value of the Rosenbrock function at
      the supplied point.
    dcoord: Array with shape [2]. The derivative of the function with
      respect to `coord`.
  """
  x, y = coord[0], coord[1]
  fv = (1 - x)**2 + 100 * (y - x**2)**2
  dfx = 2 * (x - 1) + 400 * x * (x**2 - y)
  dfy = 200 * (y - x**2)
  return fv, jnp.stack([dfx, dfy])

start = jnp.float32([-1.2, 1.0])
results = tfp.optimizer.lbfgs_minimize(
    rosenbrock, initial_position=start, tolerance=1e-5)

results.position  # => DeviceArray([1.0000001, 1.0000002], dtype=float32)
Also, in tomorrow's nightly build, BFGS and Nelder-Mead optimizers will be supported in TFP JAX.
Now that BFGS support has been merged and provides some general scaffolding (see #3101), we'd love to get L-BFGS (and other optimizers) in JAX proper, as well as support for pytrees and implicit differentiation.
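For reference, a minimal sketch of what the merged scaffolding enables, assuming the BFGS from #3101 is exposed as `jax.scipy.optimize.minimize` (a toy batched least-squares problem; the names here are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def solve_one(target):
    # A small quadratic problem whose minimizer is `target` itself.
    loss = lambda x: jnp.sum((x - target) ** 2)
    return minimize(loss, jnp.zeros_like(target), method="BFGS").x

# The entire optimization loop can be compiled and batched over many problems.
batched_solve = jax.jit(jax.vmap(solve_one))
targets = jnp.arange(12.0).reshape(4, 3)
solutions = batched_solve(targets)  # shape (4, 3), one solution per row
```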
Thanks @shoyer for your hard work getting this in. I'll be happy to add some more. I'm interested in getting (L)BFGS in with bound constraints. I can also add pytree support to BFGS, and as promised in my code comment in the line search, I still intend to profile the line search between the versions using `where`s and `cond`s.
Hi, I'm interested in mixing an Adam optimizer and an L-BFGS optimizer. JAX provides an Adam optimizer, so I used that. But I don't understand how I can turn the network parameters from JAX's Adam optimizer into the input of `tfp.optimizer.lbfgs_minimize()`.
The code below conceptually shows what I want to do: it tries to optimize a network with Adam first, and then uses L-BFGS.
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels=[]):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)
    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10] * 3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10, 2))
        y = np.random.random_sample((10, 1))
        opt_state = step(i, opt_state, x, y)

    # 4. lbfgs optimization
    _x = np.random.random_sample((10, 2))
    _y = np.random.random_sample((10, 1))

    def func(params):
        return loss(params, _x, _y)

    net_params = get_params(opt_state)
    results = tfp.optimizer.lbfgs_minimize(
        func, initial_position=net_params, tolerance=1e-5)

if __name__ == "__main__":
    main()
Any kind of advice would be very helpful to me. @brianwa84 Could you provide an example that mixes the two kinds of optimizers?
Re: @nrontsis https://github.com/google/jax/issues/1400#issuecomment-648453239.
I am confused about the benefits of implementing BFGS or any other non-trivial optimisation solvers in JAX.
I made a comparison of BFGS against scipy+numpy, scipy+jax, and pure jax on a benchmark problem, N-d least squares + L1 regularisation.
@KeunwooPark I can't speak to using TFP in combination, but maybe https://github.com/google/jax/issues/3847 is interesting to you. If you wanted to use pure JAX L-BFGS then you'll need to wait. I plan on implementing it mid-November, as well as pytree arguments.
That's quite impressive! Can you share the script used to generate it?
Just put it up here: https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae
EDIT: the scipy+numpy and scipy+jitted(func) variants use numerical Jacobians, which require many more function evaluations. It is best to compare pure JAX against scipy+jitted(func and grad).
@Joshuaalbert Thank you for your comment! I guess I have to think more about a workaround or just use TensorFlow. Or use your L-BFGS implementation. According to your graph, mixing scipy and jax doesn't seem to be a good idea. I'm interested in implementing L-BFGS, but I'm really new to these concepts and still learning.
@KeunwooPark If you're looking for something immediately, scipy+jax will probably do the job, unless speed is crucial.
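For anyone who wants the scipy+jax combination right away, the "scipy + jitted func and grad" variant mentioned above is just a thin wrapper; a minimal sketch:

```python
import numpy as np
import scipy.optimize
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum((x - 3.0) ** 2)

# JAX evaluates (value, gradient) in one jitted call; SciPy drives the iterations.
value_and_grad = jax.jit(jax.value_and_grad(loss))

def scipy_fun(x):
    v, g = value_and_grad(jnp.asarray(x))
    # SciPy's L-BFGS-B expects float64 NumPy arrays.
    return float(v), np.asarray(g, dtype=np.float64)

result = scipy.optimize.minimize(scipy_fun, np.zeros(5), jac=True, method="L-BFGS-B")
print(result.x)  # approximately [3., 3., 3., 3., 3.]
```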
@KeunwooPark I modified your example to feed the neural network weights into L-BFGS. Unfortunately, `tfp.optimizer.lbfgs_minimize` does not support optimizing over structures of tensors/arrays, but you can concatenate the network weights together and then split them back up in the loss function.
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels=[]):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)
    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10] * 3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10, 2))
        y = np.random.random_sample((10, 1))
        opt_state = step(i, opt_state, x, y)

    # 4. lbfgs optimization
    _x = np.random.random_sample((10, 2))
    _y = np.random.random_sample((10, 1))

    net_params = get_params(opt_state)

    def concat_params(params):
        flat_params, params_tree = jax.tree_util.tree_flatten(params)
        params_shape = [x.shape for x in flat_params]
        return jnp.concatenate([x.reshape(-1) for x in flat_params]), (params_tree, params_shape)

    param_vector, (params_tree, params_shape) = concat_params(net_params)

    @jax.value_and_grad
    def func(param_vector):
        split_params = jnp.split(
            param_vector, np.cumsum([np.prod(s) for s in params_shape[:-1]]))
        flat_params = [x.reshape(s) for x, s in zip(split_params, params_shape)]
        params = jax.tree_util.tree_unflatten(params_tree, flat_params)
        return loss(params, _x, _y)

    results = tfp.optimizer.lbfgs_minimize(
        jax.jit(func), initial_position=param_vector, tolerance=1e-5)

if __name__ == "__main__":
    main()
Hope this is helpful! In the long run, better support for these structured inputs to L-BFGS in TFP will likely be the path forward.
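A possible simplification of the flatten/split bookkeeping above is `jax.flatten_util.ravel_pytree` (assuming it is available in your JAX version), which returns both the flat vector and the inverse mapping in one call; a sketch, reusing the names from the example above:

```python
import jax
from jax.flatten_util import ravel_pytree
from tensorflow_probability.substrates import jax as tfp

# `net_params`, `loss`, `_x` and `_y` are assumed to be defined as in the example above.
param_vector, unravel = ravel_pytree(net_params)

@jax.value_and_grad
def func(param_vector):
    # Rebuild the original parameter pytree from the flat vector.
    return loss(unravel(param_vector), _x, _y)

results = tfp.optimizer.lbfgs_minimize(
    jax.jit(func), initial_position=param_vector, tolerance=1e-5)
```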
@sharadmv Wow. Thank you very much! This helps me a lot.
Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py
But wasn't sure if anyone else was interested or had something similar already.