patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Can't use Optimistix solvers with `eqx.Module`s and filtered transformations #26

Closed: eringrant closed this issue 7 months ago

eringrant commented 7 months ago

Thanks very much for this library! I understand it's not the primary use case, but I'd like to use optimistix with first-order gradient optimizers and standard neural nets so that I can vectorize optimizers. (Specifically, I'd like to train an ensemble as in equinox, but with each member of the ensemble paired with a distinct optimizer.)

I run into an error when using optx.GradientDescent with an eqx.Module. Here is an MWE adapted from some example code in this repo:

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2

model = eqx.nn.MLP(
  in_size=N,
  out_size=N,
  width_size=K,
  depth=1,
  activation=jax.nn.relu,
  key=jax.random.PRNGKey(42),
)

@eqx.filter_jit
def loss(model, args):
  x, y = args
  pred_y = eqx.filter_vmap(model)(x)
  loss = jnp.mean((pred_y - y) ** 2)
  aux = None
  return loss, aux

optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

init = eqx.filter_jit(eqx.Partial(optimizer.init, fn=loss, options=options, f_struct=f_struct,
                      aux_struct=aux_struct, tags=tags))
step = eqx.filter_jit(eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags))
terminate = eqx.filter_jit(eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags))
postprocess = eqx.filter_jit(eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags))

state = init(y=model, args=(x, y))
done, result = terminate(y=model, args=(x, y), state=state)

while not done:
  model, state, _ = step(y=model, args=(x, y), state=state)
  done, result = terminate(y=model, args=(x, y), state=state)
  print(f"Evaluating iteration with loss value {loss(model, (x, y))[0]}.")

if result != optx.RESULTS.successful:
  print("Failed!")

model, _, _ = postprocess(
  y=model,
  aux=None,
  args=(x, y),
  state=state,
  result=result,
)
print(f"Found solution with loss value {loss(model, (x, y))[0]}.")

gives me:

TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x1022171d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type

at this line: https://github.com/patrick-kidger/optimistix/blob/53d017dd7fd125dbedaeb710070ab8171e7d4ae8/optimistix/_solver/gradient_methods.py#L155

which, if I understand correctly, is the result of jax.eval_shape hitting non-arrays. How can I filter for arrays in model, or is there a different recommended usage pattern here?

patrick-kidger commented 7 months ago

Thank you! I'm glad you're enjoying the library.

So the reason for this is that Optimistix requires y to have type PyTree[Array]: that is, all of the leaves of this PyTree must be arrays. However, eqx.nn.MLP is a PyTree containing both arrays and non-arrays (in this case, the activation function, which is the custom_jvp object mentioned in the error message).

Probably the simplest way to tackle this is to use eqx.{partition,combine}, as in "Option 1" here.

We could maybe lift this requirement in Optimistix itself: this wasn't a priority to begin with, but it could probably be done with a lot of little tweaks.

eringrant commented 7 months ago

Thanks so much for the quick response! I've indeed been able to get the example working with the eqx.{partition,combine} strategy you suggested (code below). I'm happy with this usage pattern for now, but it would be nice to be able to use filtered transformations (a lot of my code uses them!).

MWE with eqx.{partition,combine}

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2

model = eqx.nn.MLP(
  in_size=N,
  out_size=N,
  width_size=K,
  depth=1,
  activation=jax.nn.relu,
  key=jax.random.PRNGKey(42),
)

def loss(params, args):
  static, x, y = args
  model = eqx.combine(params, static)
  pred_y = eqx.filter_vmap(model)(x)
  loss = jnp.mean((pred_y - y) ** 2)
  aux = None
  return loss, aux

optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

init = eqx.Partial(optimizer.init, fn=loss, options=options, f_struct=f_struct,
                   aux_struct=aux_struct, tags=tags)
step = eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags)
terminate = eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags)
postprocess = eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags)

params, static = eqx.partition(model, eqx.is_array)
state = init(y=params, args=(static, x, y))
done, result = terminate(y=params, args=(static, x, y), state=state)

# `static` never changes, so only `params` needs to be carried through the loop.
while not done:
  params, state, _ = step(y=params, args=(static, x, y), state=state)
  done, result = terminate(y=params, args=(static, x, y), state=state)
  print(f"Evaluating iteration with loss value {loss(params, (static, x, y))[0]}.")

if result != optx.RESULTS.successful:
  print("Failed!")

params, _, _ = postprocess(
  y=params,
  aux=None,
  args=(static, x, y),
  state=state,
  result=result,
)
# Reassemble the trained MLP from the optimised arrays and the untouched static part.
model = eqx.combine(params, static)
print(f"Found solution with loss value {loss(params, (static, x, y))[0]}.")