patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/

Zero implicit gradients when using `ImplicitAdjoint` with CG solver #54

Open itk22 opened 8 months ago

itk22 commented 8 months ago

Hi @patrick-kidger and @packquickly,

I was trying to implement the following meta-learning example from JAXopt in Optimistix: Few-shot Adaptation with Model Agnostic Meta-Learning. However, I ran into an issue with implicit differentiation through the inner loop. The example below runs well when using optx.RecursiveCheckpointAdjoint, but when I try to recreate the iMAML setup by using optx.ImplicitAdjoint with a CG solver (20 steps), all the meta-gradients are zero and the meta-optimiser doesn't change at all during training. Could you please help me identify the issue with the code? It seems to be an implementation detail of implicit adjoints that differs between JAXopt and Optimistix.

Here is an MWE:

import optimistix as optx
import equinox as eqx
import lineax as lx
import jax
import jax.random as jr
import jax.numpy as jnp
import optax

key = jr.PRNGKey(0)
model = eqx.nn.MLP(1, 1, 40, 2, key=key)

sine_target = lambda x: 1.0 * jnp.sin(x - 0.5) # Target function
x = jr.normal(key, (10, 1)) # Randomly drawn inputs for validation
y_true = sine_target(x)

opt = optx.OptaxMinimiser(optax.adam(1e-3, eps_root=1e-8), 1e-7, 1e-7)  # Inner optimiser (rtol=1e-7, atol=1e-7)
params, static = eqx.partition(model, eqx.is_inexact_array)

def apply_model(params, x):
    model = eqx.combine(params, static)
    return jax.vmap(model)(x)

def loss_fn(params, args):
    y_pred = apply_model(params, x)
    loss = jnp.mean(jnp.square(y_pred - y_true))
    return loss, loss  # Also return the loss as aux so it can be read off the solution

def adapt_fn(params):
    sol = optx.minimise(loss_fn,
                        opt,
                        params,
                        None,
                        has_aux=True,
                        max_steps=2,
                        throw=False,
                        adjoint=optx.ImplicitAdjoint(lx.CG(1e-7, 1e-7, max_steps=10)),
                        tags=lx.positive_semidefinite_tag)
    return sol.aux # Return the final loss only

loss, grad = jax.value_and_grad(adapt_fn)(params)  # Meta-gradient w.r.t. the initial parameters

print(f"Final loss: {loss:.5f}")
print(f"Gradient: {grad.layers[0].weight}")
patrick-kidger commented 8 months ago

Hey there! Thanks for the issue.

Would you be able to condense your code down to a single MWE? (Preferably around 20 lines of code.) For example, we probably don't need the details of your training loop, the fact that it's batched, etc. Moreover, I'm afraid this code won't run -- if nothing else, it currently doesn't have any import statements.

itk22 commented 8 months ago

Hi @patrick-kidger,

I updated the original post with a condensed MWE.

In the above code, when I use optx.RecursiveCheckpointAdjoint(), I am able to recover the correct gradients. However, when I use optx.ImplicitAdjoint with a CG solver, the gradients are all exactly zero. To be fair, I was not expecting this to work out of the box, because (1) adapt_fn does not find the exact solution to the inner optimisation problem, and (2) even for a small network this seems to be a rather difficult calculation. However, the JAXopt example I shared above indicates that the gradients can be calculated correctly in a similar scenario using CG.

Because of this, I started wondering whether there is a fundamental difference in how implicit adjoints are calculated in the two packages. My instinct is that the mismatch might have to do with the handling of higher-order terms, but I am curious to hear your opinion, and whether it is something that can be quickly patched.
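
For concreteness, this is the comparison I mean (just an illustrative sketch reusing the MWE above): swapping only the adjoint for optx.RecursiveCheckpointAdjoint() and leaving everything else unchanged gives non-zero gradients.

def adapt_fn_unrolled(params):
    # Same inner problem as adapt_fn above, but differentiated by unrolling
    # the optimisation steps instead of via the implicit function theorem.
    sol = optx.minimise(loss_fn, opt, params, None,
                        has_aux=True, max_steps=2, throw=False,
                        adjoint=optx.RecursiveCheckpointAdjoint())
    return sol.aux

_, grad_unrolled = jax.value_and_grad(adapt_fn_unrolled)(params)
print(f"Unrolled gradient: {grad_unrolled.layers[0].weight}")  # non-zero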

patrick-kidger commented 7 months ago

So in a minimisation problem, the solution stays constant as you perturb the initial parameters. Regardless of where you start, you should expect to converge to the same solution! So in fact a zero gradient is what is expected. (Imagine finding argmin_x x^2. It doesn't matter whether you start at x=1 or x=1.1; either way your output will be x=0.)

The fact that you get a nonzero gradient via RecursiveCheckpointAdjoint will be because you are taking so few steps that you are not actually converging to the minimum at all. (In the above example, you might only converge as far as x=0.5 or x=0.6.) So I think for a meta-learning use-case, RecursiveCheckpointAdjoint is probably the correct thing to be doing!
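
To make that concrete, here is a toy sketch (purely illustrative, not using Optimistix): minimise x^2 by gradient descent and differentiate the final iterate with respect to the starting point. After a few steps the gradient is still non-zero; once (approximately) converged it vanishes.

import jax

def run(x0, n_steps, lr=0.1):
    # Gradient descent on f(x) = x^2: each step is x <- x - lr * 2 * x.
    x = x0
    for _ in range(n_steps):
        x = x - lr * 2.0 * x
    return x

print(jax.grad(run)(1.0, 5))    # (1 - 0.2)^5 ≈ 0.33: the output still depends on x0
print(jax.grad(run)(1.0, 500))  # ≈ 0.0: at convergence, x0 no longer matters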

The fact that JAXopt appears to do otherwise is possibly a bug in JAXopt. (?)

That aside, some comments on your implementation:

itk22 commented 7 months ago

Hi Patrick,

Thank you for your thorough response. It is true that the solution to a minimisation problem is independent of the initial parameters, which should lead to zero implicit gradients. As you noted, the gradients from RecursiveCheckpointAdjoint are non-zero in the above MWE because we take very few optimisation steps. I set that number low to emulate a typical bi-level meta-learning setup where, in the inner loop, we do not fully optimise the model parameters for each task but rather take just a few steps of optimisation. This is because the goal is not to find the true optimum for any single task, but rather to optimise the initialisation so that the model can quickly adapt to related tasks. In this case, the gradients through the inner loop must be non-zero for the initialisation to evolve across outer-loop iterations. Also, iMAML adds a special regularising term to ensure the meta-gradients remain non-zero even for a larger number of inner steps.

So, there is no bug in JAXopt, and it only appeared so because of my incomplete explanation. The iMAML paper that I am following uses CG because it avoids forming a Hessian matrix. It also seems that CG-like iterative solvers are used quite extensively within JAXopt - as far as I understand this again has to do with their matrix-free nature. Optimistix, on the other hand seems to be using direct solvers as a default.