JTT94 opened this issue 2 years ago (status: Open)
This is a frequently-asked question. Implicit differentiation can only differentiate the solution with respect to arguments that are passed explicitly to run, so the outer variable has to be threaded through as an argument of the objective rather than captured by the closure. You need to write it like this:
from jaxopt import LBFGS
import jax.numpy as jnp
import jax

init_x = jnp.asarray(0.0)  # example initial value for the inner variable
y = jnp.asarray(1.5)       # example outer variable to differentiate w.r.t.

def implicit_layer(outer_y):
  def closure(x, inner_y):
    loss = jnp.sum(x**2) - jnp.sum(x * inner_y)
    return loss

  lbfgs = LBFGS(fun=closure, tol=1e-5, stepsize=1e-1, maxiter=100,
                history_size=5, use_gamma=True)
  # inner_y is forwarded to run(), so implicit diff can differentiate w.r.t. it.
  out, _ = lbfgs.run(init_x, inner_y=outer_y)
  return out

jax.grad(implicit_layer)(y)  # = 0.5, since argmin_x (x**2 - x*y) = y / 2
Thank you! This works, apologies if it's in the docs and I missed it.
Also, there is an L-BFGS optimizer in TensorFlow Probability (it does not require TensorFlow itself) that works well with JAX; see https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX. In some initial experiments it seemed a bit faster than the jaxopt / scipy implementations and does not require the scipy numpy wrappers. However, I have not tried the implicit gradient for a closure as mentioned above.
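For reference, a minimal sketch of calling TFP's L-BFGS on the JAX substrate directly; the quadratic objective here is just for illustration:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

def value_and_grad(x):
  # lbfgs_minimize expects a function returning (value, gradient).
  return jax.value_and_grad(lambda x: jnp.sum((x - 2.0) ** 2))(x)

results = tfp.optimizer.lbfgs_minimize(
    value_and_grad,
    initial_position=jnp.zeros(3),
    tolerance=1e-7,
    max_iterations=100)
print(results.position)  # should be close to [2., 2., 2.]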
Let me know if it's worth adding, happy to create a PR, something like this:
from dataclasses import dataclass
from typing import Any, Callable, Optional

import jax
from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from tensorflow_probability.substrates import jax as jax_tfp


@dataclass(eq=False)
class TFPLBFGS(base.Solver):
  fun: Callable = None
  tol: Optional[float] = 1e-7
  maxiter: int = 500
  jit: bool = True
  implicit_diff_solve: Optional[Callable] = None
  has_aux: bool = False

  def optimality_fun(self, sol, *args, **kwargs):
    """Optimality function mapping compatible with `@custom_root`."""
    return self._grad_fun(sol, *args, **kwargs)

  def _run(self, init_params, bounds, *args, **kwargs):
    # Close over *args/**kwargs so lbfgs_minimize only sees the position.
    def value_and_grad(x):
      return self._value_and_grad_fun(x, *args, **kwargs)

    optim_results = jax_tfp.optimizer.lbfgs_minimize(
        value_and_grad,
        initial_position=init_params,
        max_iterations=self.maxiter,
        tolerance=self.tol)
    return base.OptStep(optim_results.position, optim_results)

  def run(self,
          init_params: Any,
          *args,
          **kwargs) -> base.OptStep:
    """Runs the solver.

    Args:
      init_params: pytree containing the initial parameters.
      *args: additional positional arguments to be passed to `fun`.
      **kwargs: additional keyword arguments to be passed to `fun`.
    Returns:
      (params, info).
    """
    return self._run(init_params, None, *args, **kwargs)

  def __post_init__(self):
    # Set up implicit diff.
    decorator = idf.custom_root(self.optimality_fun,
                                has_aux=True,
                                solve=self.implicit_diff_solve)
    # pylint: disable=g-missing-from-attributes
    self.run = decorator(self.run)
    if self.has_aux:
      # Capture the original function to avoid a self-referential lambda.
      raw_fun = self.fun
      self.fun = lambda x, *args, **kwargs: raw_fun(x, *args, **kwargs)[0]
    # Pre-compile useful functions.
    self._grad_fun = jax.grad(self.fun)
    self._value_and_grad_fun = jax.value_and_grad(self.fun)
    if self.jit:
      self._grad_fun = jax.jit(self._grad_fun)
      self._value_and_grad_fun = jax.jit(self._value_and_grad_fun)
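Hypothetical usage, mirroring the jaxopt example above (implicit_layer and the quadratic objective are made up for illustration):

import jax
import jax.numpy as jnp

def implicit_layer(outer_y):
  def closure(x, inner_y):
    return jnp.sum(x**2) - jnp.sum(x * inner_y)

  solver = TFPLBFGS(fun=closure, tol=1e-7, maxiter=100)
  # inner_y is passed positionally to run() so implicit diff can see it.
  params, _ = solver.run(jnp.zeros(3), outer_y)
  return jnp.sum(params)

jax.grad(implicit_layer)(jnp.ones(3))  # each entry should be 0.5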
jaxopt.LBFGS is also a pure-JAX implementation (jaxopt.ScipyMinimize is a wrapper). You may want to try jaxopt.LBFGS(..., linesearch="zoom"), which could potentially give better results. This line search is scheduled to become the default one soon, replacing the backtracking line search.
Regarding the implementation in TFP, the key difference is that it uses the so-called Hager-Zhang line search. We have been discussing including this line search in JAXopt. We should be able to copy the generated JAX code from TFP (contribution welcome!).
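For example, switching line searches is a one-argument change (sketch; the objective is just for illustration):

import jax.numpy as jnp
from jaxopt import LBFGS

def loss(x, y):
  return jnp.sum(x**2) - jnp.sum(x * y)

# Same solver as before, but using the zoom line search instead of backtracking.
solver = LBFGS(fun=loss, linesearch="zoom", tol=1e-5, maxiter=100)
params, state = solver.run(jnp.zeros(3), y=jnp.ones(3))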
Is it possible to do the following? I would like to take a gradient through an argmin involving a closure, where the function passed to the solver contains y, which is not passed as an argument explicitly. This seems to be causing an issue. (This is just a toy example of what I have in mind.)
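A rough sketch of the pattern I mean (the quadratic objective and names are just illustrative):

import jax
import jax.numpy as jnp
from jaxopt import LBFGS

def implicit_layer(y):
  # y is only captured by the closure, not passed to run() explicitly.
  def closure(x):
    return jnp.sum(x**2) - jnp.sum(x * y)

  lbfgs = LBFGS(fun=closure, tol=1e-5, maxiter=100)
  out, _ = lbfgs.run(jnp.zeros(3))
  return jnp.sum(out)

jax.grad(implicit_layer)(jnp.ones(3))  # fails: implicit diff cannot see y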
The error message is: