Closed tristandeleu closed 2 years ago
Hi Tristan,
I believe the error is caused by closing over `inputs` and `targets` in `inner_loss`. The following version of your MWE works for me:
```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax


def loss(params, inputs, targets):
    outputs = jnp.matmul(inputs, params)
    return jnp.mean(optax.l2_loss(outputs, targets))


def inner_loss(params, init_params, inputs, targets):
    # Pass `inputs`/`targets` as explicit arguments instead of closing over them.
    loss_val = loss(params, inputs, targets)
    prox_val = 0.5 * jnp.sum((params - init_params) ** 2)
    return loss_val + prox_val


solver = jaxopt.GradientDescent(
    inner_loss,
    stepsize=0.1,
    maxiter=5,
    acceleration=False,
    implicit_diff=True,
)


def adapt(init_params, inputs, targets):
    params, _ = solver.run(init_params, init_params, inputs, targets)
    return params


@partial(jax.vmap, in_axes=(None, 0, 0))
def outer_loss_(init_params, inputs, targets):
    adapted_params, _ = solver.run(init_params, init_params, inputs, targets)
    return loss(adapted_params, inputs, targets)


def outer_loss(init_params, inputs, targets):
    losses = outer_loss_(init_params, inputs, targets)
    return jnp.mean(losses)


meta_optimizer = optax.adam(1e-3)


@jax.jit
def meta_update(init_params, state, inputs, targets):
    value, grads = jax.value_and_grad(outer_loss)(init_params, inputs, targets)
    # Apply gradient update
    updates, state = meta_optimizer.update(grads, state, init_params)
    init_params = optax.apply_updates(init_params, updates)
    return init_params, state, value


# Dummy data
inputs = jnp.zeros((2, 3, 5))
targets = jnp.ones((2, 3, 7))
init_params = jnp.zeros((5, 7))

state = meta_optimizer.init(init_params)
for _ in range(10):
    init_params, state, value = meta_update(init_params, state, inputs, targets)
    print(value)
```
That's a very simple fix, thanks a lot Felipe!
To give a bit of context: I am trying to use jaxopt for meta-learning, and more specifically to implement iMAML. I need to call a solver for each task (each task has its own solver), where the solvers are simply gradient descent with implicit differentiation to compute the gradients. In meta-learning the update of the parameters is usually made based on a batch of tasks, so I want to apply multiple solvers in parallel, one for each task in my batch. To do that, I would like to `vmap` the outer loss, which calls a `GradientDescent` solver.

When I use `vmap` over a function that calls `GradientDescent` with `implicit_diff=True`, I get an `UnexpectedTracerError`. Here is a minimal example (it's a bit verbose, sorry):

I get the following error:
The same snippet works when I don't apply `vmap` (i.e. when I'm working with a single task):

It also works if I'm using `vmap` with unrolled optimization instead of implicit differentiation (`unroll=True`, `implicit_diff=False`):

Here are the versions of jax/jaxopt I am using:
```
jax == 0.3.4
jaxlib == 0.3.2
jaxopt == 0.3.1
```