Thank you! I'm glad you're enjoying the library.
So the reason for this is that Optimistix requires `y` to have type `PyTree[Array]`: that is, all of the leaves of this PyTree must be arrays. However, `eqx.nn.MLP` is a PyTree containing both arrays and non-arrays (in this case, the activation function, which is the `custom_jvp` object mentioned in the error message).

Probably the simplest way to tackle this is to use `eqx.{partition,combine}`, as in "Option 1" here.
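Concretely, the pattern looks something like this (a minimal sketch with a toy objective; the exact names are illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=4, depth=1,
                 key=jax.random.PRNGKey(0))

# Split the module into an array-only PyTree (a valid Optimistix `y`)
# and a static remainder holding the non-array leaves (e.g. the activation).
params, static = eqx.partition(mlp, eqx.is_array)

def fn(params, args):
    static, x = args
    model = eqx.combine(params, static)  # reassemble before calling
    return jnp.sum(model(x) ** 2)
```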
We could maybe lift this requirement in Optimistix itself: this wasn't a priority to begin with, but it could probably be done with a lot of little tweaks.
Thanks so much for the quick response! I've indeed been able to get the example working with the `eqx.{partition,combine}` strategy (code below) like you suggested. I'm happy with this usage pattern for now, but it would be nice to be able to use filtered transformations (as a lot of my code does!).
`eqx.{partition,combine}` example:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2

model = eqx.nn.MLP(
    in_size=N,
    out_size=N,
    width_size=K,
    depth=1,
    activation=jax.nn.relu,
    key=jax.random.PRNGKey(42),
)


def loss(params, args):
    # Reassemble the model from the array-only `params` and the `static` part.
    static, x, y = args
    model = eqx.combine(params, static)
    pred_y = eqx.filter_vmap(model)(x)
    loss = jnp.mean((pred_y - y) ** 2)
    aux = None
    return loss, aux


optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

# Bind the fixed arguments of each solver method.
init = eqx.Partial(
    optimizer.init, fn=loss, options=options, f_struct=f_struct,
    aux_struct=aux_struct, tags=tags,
)
step = eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags)
terminate = eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags)
postprocess = eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags)

# Optimize over the array leaves only; the static part rides along in `args`.
params, static = eqx.partition(model, eqx.is_array)
state = init(y=params, args=(static, x, y))
done, result = terminate(y=params, args=(static, x, y), state=state)

while not done:
    params, static = eqx.partition(model, eqx.is_array)
    params, state, _ = step(y=params, args=(static, x, y), state=state)
    done, result = terminate(y=params, args=(static, x, y), state=state)
    model = eqx.combine(params, static)
    print(f"Evaluating iteration with loss value {loss(params, (static, x, y))[0]}.")

if result != optx.RESULTS.successful:
    print("Failed!")

params, static = eqx.partition(model, eqx.is_array)
params, _, _ = postprocess(
    y=params,
    aux=None,
    args=(static, x, y),
    state=state,
    result=result,
)
print(f"Found solution with loss value {loss(params, (static, x, y))[0]}.")
```
Thanks very much for this library! Though I understand it's not the primary use case, I'd like to use `optimistix` with first-order gradient optimizers and standard neural nets, to make use of the ability to vectorize optimizers. (Specifically, I'd like to train an ensemble like in equinox, but where each member of the ensemble is paired with a distinct optimizer.)

I run into an error when using `optx.GradientDescent` with an `eqx.Module`. Adapting some example code from this repo for an MWE gives me an error
at this line: https://github.com/patrick-kidger/optimistix/blob/53d017dd7fd125dbedaeb710070ab8171e7d4ae8/optimistix/_solver/gradient_methods.py#L155
which, if I understand correctly, is the result of `jax.eval_shape` hitting non-arrays. How can I filter for arrays in `model`, or is there a different recommended usage pattern here?
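For instance, here's a tiny sketch of what I believe trips this (my own repro attempt, not taken from the traceback):

```python
import equinox as eqx
import jax

mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=4, depth=1,
                 key=jax.random.PRNGKey(0))

# Fails: the activation leaf of the MLP is a function (a custom_jvp object),
# which jax.eval_shape cannot interpret as an abstract array.
jax.eval_shape(lambda m: m, mlp)
```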