google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Gradient through closure #285

Open JTT94 opened 2 years ago

JTT94 commented 2 years ago

Is it possible to do the following?

I would like to take a gradient through an argmin involving a closure, where the function passed to the solver uses y without receiving it explicitly as an argument. This seems to be causing an issue.

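from jaxopt import LBFGS
import jax.numpy as jnp
import jax

def implicit_layer(y):
  def closure(x):
    # y is captured from the enclosing scope, not passed as an argument.
    loss = jnp.sum(x**2) - jnp.sum(x * y)
    return loss

  lbfgs = LBFGS(fun=closure, tol=1e-5, stepsize=1e-1, maxiter=100, history_size=5,
                use_gamma=True)
  out, _ = lbfgs.run(y)
  return jnp.sum(out)  # jax.grad needs a scalar output

y = jnp.ones(3)
jax.grad(implicit_layer)(y)  # raises the CustomVJPException below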

(this is just a toy example of what I have in mind)

The error message is:

CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
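The error says a custom VJP rule only covers the explicit inputs of the function it decorates. The following pure-JAX sketch (not jaxopt's internals; `optimality`, `argmin_toy` and its fwd/bwd rules are hypothetical names) shows the pattern the message asks for: the value being differentiated, `y`, is an explicit argument, and the backward rule applies the implicit function theorem to the toy problem above, whose optimality condition is F(x, y) = 2x - y = 0, so x*(y) = y / 2.

```python
import jax
import jax.numpy as jnp


def optimality(x, y):
  # Gradient of the toy inner objective sum(x**2) - sum(x * y) w.r.t. x.
  return 2.0 * x - y


@jax.custom_vjp
def argmin_toy(y):
  # Forward pass: plain gradient descent (step 0.1) on the inner objective.
  def body(_, x):
    return x - 0.1 * optimality(x, y)
  return jax.lax.fori_loop(0, 200, body, jnp.zeros_like(y))


def argmin_toy_fwd(y):
  x_star = argmin_toy(y)
  # Save the solution and the explicit input for the backward pass.
  return x_star, (x_star, y)


def argmin_toy_bwd(res, g):
  x_star, y = res
  # Implicit function theorem: solve (dF/dx)^T u = g; here dF/dx = 2 I.
  u = g / 2.0
  # The cotangent for y is then -(dF/dy)^T u.
  _, vjp_y = jax.vjp(lambda y_: optimality(x_star, y_), y)
  (g_y,) = vjp_y(-u)
  return (g_y,)


argmin_toy.defvjp(argmin_toy_fwd, argmin_toy_bwd)

y = jnp.array([1.0, 2.0, 3.0])
grad_y = jax.grad(lambda y: jnp.sum(argmin_toy(y)))(y)
print(grad_y)  # [0.5 0.5 0.5], since x*(y) = y / 2
```

Because `y` is a parameter of `argmin_toy`, JAX knows where to send its cotangent; had the solver only closed over `y`, the rule would have no slot for it, which is exactly what the exception reports.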
mblondel commented 2 years ago

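This is a frequently asked question. The value you differentiate with respect to must be an explicit argument of the objective, forwarded through `run`. You need to write it like this: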

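from jaxopt import LBFGS
import jax.numpy as jnp
import jax

def implicit_layer(outer_y):
  def closure(x, inner_y):
    loss = jnp.sum(x**2) - jnp.sum(x * inner_y)
    return loss

  lbfgs = LBFGS(fun=closure, tol=1e-5, stepsize=1e-1, maxiter=100, history_size=5,
                use_gamma=True)
  init_x = jnp.zeros_like(outer_y)
  # outer_y reaches `closure` as an explicit keyword argument of `run`,
  # so jaxopt's implicit differentiation can compute gradients for it.
  out, _ = lbfgs.run(init_x, inner_y=outer_y)
  return jnp.sum(out)  # jax.grad needs a scalar output

y = jnp.ones(3)
jax.grad(implicit_layer)(y)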
JTT94 commented 2 years ago

Thank you! This works, apologies if it's in the docs and I missed it.

Also, there is an L-BFGS optimizer in TFP (TensorFlow Probability) that does not require TensorFlow itself and works well with JAX; see https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX. In some initial experiments it seemed a bit faster than the jaxopt / scipy implementations, and it does not require the SciPy/NumPy wrappers. However, I have not tried the implicit gradient for a closure as mentioned above.

Let me know if it's worth adding; I'm happy to create a PR. Something like this:


from dataclasses import dataclass
from typing import Any, Callable, Optional

import jax
from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from tensorflow_probability.substrates import jax as jax_tfp


@dataclass(eq=False)
class TFPLBFGS(base.Solver):
  """L-BFGS from TFP's JAX substrate, differentiable via jaxopt's implicit diff."""

  fun: Callable = None
  tol: Optional[float] = 1e-7
  maxiter: int = 500
  jit: bool = True
  implicit_diff_solve: Optional[Callable] = None
  has_aux: bool = False

  def optimality_fun(self, sol, *args, **kwargs):
    """Optimality function mapping compatible with `@custom_root`."""
    return self._grad_fun(sol, *args, **kwargs)

  def _run(self, init_params, bounds, *args, **kwargs):
    # TFP calls the objective with the position only, so bind the extra
    # arguments here; otherwise the solution would ignore them.
    def value_and_grad(params):
      return self._value_and_grad_fun(params, *args, **kwargs)

    optim_results = jax_tfp.optimizer.lbfgs_minimize(
        value_and_grad,
        initial_position=init_params,
        max_iterations=self.maxiter,
        tolerance=self.tol)

    return base.OptStep(optim_results.position, optim_results)

  def run(self,
          init_params: Any,
          *args,
          **kwargs) -> base.OptStep:
    """Runs the solver.
    Args:
      init_params: pytree containing the initial parameters.
      *args: additional positional arguments to be passed to `fun`.
      **kwargs: additional keyword arguments to be passed to `fun`.
    Returns:
      (params, info).
    """
    return self._run(init_params, None, *args, **kwargs)

  def __post_init__(self):
    # Set up implicit diff: gradients w.r.t. the extra arguments of `run`
    # come from implicit differentiation of `optimality_fun`.
    decorator = idf.custom_root(self.optimality_fun,
                                has_aux=True,
                                solve=self.implicit_diff_solve)
    # pylint: disable=g-missing-from-attributes
    self.run = decorator(self.run)

    if self.has_aux:
      # Capture the original function; referencing self.fun inside the
      # lambda would make it call itself recursively.
      fun_with_aux = self.fun
      self.fun = lambda x, *args, **kwargs: fun_with_aux(x, *args, **kwargs)[0]

    # Pre-compile useful functions.
    self._grad_fun = jax.grad(self.fun)
    self._value_and_grad_fun = jax.value_and_grad(self.fun)
    if self.jit:
      self._grad_fun = jax.jit(self._grad_fun)
      self._value_and_grad_fun = jax.jit(self._value_and_grad_fun)
mblondel commented 2 years ago

jaxopt.LBFGS is also a pure-JAX implementation (jaxopt.ScipyMinimize is the SciPy wrapper). You may want to try jaxopt.LBFGS(..., linesearch="zoom"), which may give better results. This line search is scheduled to become the default soon, replacing the backtracking line search.

Regarding the implementation in TFP, the key difference is that it uses the so-called Hager-Zhang line search. We have been discussing including this line search in JAXopt. We should be able to copy the generated JAX code from TFP (contributions welcome!).

Btw, have you seen this and this?