google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io

UnexpectedTracerError when using vmap over solvers with implicit differentiation #198

Closed: tristandeleu closed this issue 2 years ago

tristandeleu commented 2 years ago

To give a bit of context: I am trying to use jaxopt for meta-learning, more specifically to implement iMAML. I need to call a solver for each task (each task has its own solver); here the solvers are simply gradient descent with implicit differentiation to compute the gradients. In meta-learning, the parameter update is usually based on a batch of tasks, so I want to apply multiple solvers in parallel, one per task in my batch. To do that, I would like to vmap the outer loss, which calls a GradientDescent solver.

When I use vmap over a function that calls GradientDescent with implicit_diff=True, I get an UnexpectedTracerError. Here is a minimal example (it's a bit verbose, sorry):

import jax.numpy as jnp
import jax
import jaxopt
import optax

from functools import partial

def loss(params, inputs, targets):
    outputs = jnp.matmul(inputs, params)
    return jnp.mean(optax.l2_loss(outputs, targets))

def adapt(init_params, inputs, targets):
    def inner_loss(params, init_params):
        return loss(params, inputs, targets) + 0.5 * jnp.sum((params - init_params) ** 2)

    solver = jaxopt.GradientDescent(
        inner_loss,
        stepsize=0.1,
        maxiter=5,
        acceleration=False,
        implicit_diff=True,
    )

    params, _ = solver.run(init_params, init_params)
    return params

def outer_loss(init_params, inputs, targets):
    @partial(jax.vmap, in_axes=(None, 0, 0))
    def outer_loss_(init_params, inputs, targets):
        adapted_params = adapt(init_params, inputs, targets)
        return loss(adapted_params, inputs, targets)
    losses = outer_loss_(init_params, inputs, targets)
    return jnp.mean(losses)

meta_optimizer = optax.adam(1e-3)

@jax.jit
def meta_update(init_params, state, inputs, targets):
    value, grads = jax.value_and_grad(outer_loss)(init_params, inputs, targets)

    # Apply gradient update
    updates, state = meta_optimizer.update(grads, state, init_params)
    init_params = optax.apply_updates(init_params, updates)

    return init_params, state, value

# Dummy data
inputs = jnp.zeros((2, 3, 5))
targets = jnp.ones((2, 3, 7))

init_params = jnp.zeros((5, 7))
state = meta_optimizer.init(init_params)

for _ in range(10):
    init_params, state, value = meta_update(init_params, state, inputs, targets)
    print(value)

I get the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (5, 7) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(float32[5,7])>with<JVPTrace(level=3/1)> with
  primal = Traced<ShapedArray(float32[5,7])>with<JaxprTrace(level=2/1)> with
    pval = (None, Traced<ShapedArray(float32[5,7])>with<BatchTrace(level=1/1)> with
  val = Traced<ShapedArray(float32[2,5,7])>with<DynamicJaxprTrace(level=0/1)>
  batch_dim = 0)
    recipe = *
  tangent = Traced<ShapedArray(float32[5,7])>with<JaxprTrace(level=2/1)> with
    pval = (ShapedArray(float32[5,7]), *)
    recipe = LambdaBinding(), BatchTrace(level=3/1)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

The same snippet works when I don't apply vmap (I'm working with a single task):

```python
import jax.numpy as jnp
import jax
import jaxopt
import optax

from functools import partial

def loss(params, inputs, targets):
    outputs = jnp.matmul(inputs, params)
    return jnp.mean(optax.l2_loss(outputs, targets))

def adapt(init_params, inputs, targets):
    def inner_loss(params, init_params):
        return loss(params, inputs, targets) + 0.5 * jnp.sum((params - init_params) ** 2)

    solver = jaxopt.GradientDescent(
        inner_loss,
        stepsize=0.1,
        maxiter=5,
        acceleration=False,
        implicit_diff=True,
    )

    params, _ = solver.run(init_params, init_params)
    return params

def outer_loss(init_params, inputs, targets):
    adapted_params = adapt(init_params, inputs, targets)
    return loss(adapted_params, inputs, targets)

meta_optimizer = optax.adam(1e-3)

@jax.jit
def meta_update(init_params, state, inputs, targets):
    value, grads = jax.value_and_grad(outer_loss)(init_params, inputs, targets)

    # Apply gradient update
    updates, state = meta_optimizer.update(grads, state, init_params)
    init_params = optax.apply_updates(init_params, updates)

    return init_params, state, value

# Dummy data
inputs = jnp.zeros((3, 5))
targets = jnp.ones((3, 7))

init_params = jnp.zeros((5, 7))
state = meta_optimizer.init(init_params)

for _ in range(10):
    init_params, state, value = meta_update(init_params, state, inputs, targets)
    print(value)
```

It also works if I'm using vmap with unrolled optimization instead of implicit differentiation (unroll=True, implicit_diff=False):

```python
import jax.numpy as jnp
import jax
import jaxopt
import optax

from functools import partial

def loss(params, inputs, targets):
    outputs = jnp.matmul(inputs, params)
    return jnp.mean(optax.l2_loss(outputs, targets))

def adapt(init_params, inputs, targets):
    def inner_loss(params, init_params):
        return loss(params, inputs, targets) + 0.5 * jnp.sum((params - init_params) ** 2)

    solver = jaxopt.GradientDescent(
        inner_loss,
        stepsize=0.1,
        maxiter=5,
        acceleration=False,
        implicit_diff=False,
        unroll=True
    )

    params, _ = solver.run(init_params, init_params)
    return params

def outer_loss(init_params, inputs, targets):
    @partial(jax.vmap, in_axes=(None, 0, 0))
    def outer_loss_(init_params, inputs, targets):
        adapted_params = adapt(init_params, inputs, targets)
        return loss(adapted_params, inputs, targets)
    losses = outer_loss_(init_params, inputs, targets)
    return jnp.mean(losses)

meta_optimizer = optax.adam(1e-3)

@jax.jit
def meta_update(init_params, state, inputs, targets):
    value, grads = jax.value_and_grad(outer_loss)(init_params, inputs, targets)

    # Apply gradient update
    updates, state = meta_optimizer.update(grads, state, init_params)
    init_params = optax.apply_updates(init_params, updates)

    return init_params, state, value

# Dummy data
inputs = jnp.zeros((2, 3, 5))
targets = jnp.ones((2, 3, 7))

init_params = jnp.zeros((5, 7))
state = meta_optimizer.init(init_params)

for _ in range(10):
    init_params, state, value = meta_update(init_params, state, inputs, targets)
    print(value)
```

Here are the versions of jax/jaxopt I am using:

fllinares commented 2 years ago

Hi Tristan, I believe the error is caused by closing over inputs and targets in inner_loss. The following version of your MWE works for me:

def loss(params, inputs, targets):
  outputs = jnp.matmul(inputs, params)
  return jnp.mean(optax.l2_loss(outputs, targets))

def inner_loss(params, init_params, inputs, targets):
  loss_val = loss(params, inputs, targets)
  prox_val = 0.5 * jnp.sum((params - init_params) ** 2)
  return loss_val + prox_val

solver = jaxopt.GradientDescent(
    inner_loss,
    stepsize=0.1,
    maxiter=5,
    acceleration=False,
    implicit_diff=True,
)

def adapt(init_params, inputs, targets):
  params, _ = solver.run(init_params, init_params, inputs, targets)
  return params

@partial(jax.vmap, in_axes=(None, 0, 0))
def outer_loss_(init_params, inputs, targets):
  adapted_params, _ = solver.run(init_params, init_params, inputs, targets)
  return loss(adapted_params, inputs, targets)

def outer_loss(init_params, inputs, targets):
  losses = outer_loss_(init_params, inputs, targets)
  return jnp.mean(losses)

meta_optimizer = optax.adam(1e-3)

@jax.jit
def meta_update(init_params, state, inputs, targets):
    value, grads = jax.value_and_grad(outer_loss)(init_params, inputs, targets)

    # Apply gradient update
    updates, state = meta_optimizer.update(grads, state, init_params)
    init_params = optax.apply_updates(init_params, updates)

    return init_params, state, value

# Dummy data
inputs = jnp.zeros((2, 3, 5))
targets = jnp.ones((2, 3, 7))

init_params = jnp.zeros((5, 7))
state = meta_optimizer.init(init_params)

for _ in range(10):
    init_params, state, value = meta_update(init_params, state, inputs, targets)
    print(value)
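
For completeness, a minimal sketch of how the adapt helper above can also be used inside the vmapped outer loss; the name outer_loss_via_adapt is just illustrative, and loss, adapt, solver, jax.vmap and partial are assumed to be defined/imported exactly as in the snippets above. The key point is the same: inputs and targets reach solver.run as explicit arguments, rather than being captured in the objective's closure.

```python
# Sketch: equivalent vmapped outer loss that goes through adapt().
# inputs/targets are mapped over the leading (task) axis and passed
# explicitly to solver.run, so no batched tracer is closed over.
@partial(jax.vmap, in_axes=(None, 0, 0))
def outer_loss_via_adapt(init_params, inputs, targets):
  adapted_params = adapt(init_params, inputs, targets)
  return loss(adapted_params, inputs, targets)
```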
tristandeleu commented 2 years ago

That's a very simple fix, thanks a lot Felipe!