patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Using `optimistix` with an `equinox` model #56

Open frostedoyster opened 6 months ago

frostedoyster commented 6 months ago

Hi everyone, thanks for the great library, and apologies in advance for this basic question. I'm trying to find the true minimum of the training loss for a small neural network, so I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.

Here is a minimal snippet which fails:

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)

I'm getting `TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>`.

What am I doing wrong? Thank you in advance.

patrick-kidger commented 6 months ago

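You have `static` and `params` the wrong way around. `eqx.partition` returns the leaves that match the filter first, so the line should read `params, static = eqx.partition(model, eqx.is_inexact_array)`. As written, the activation function ends up in the object handed to the solver, which is where the dtype error comes from.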

(Caveat: I've not tried running the code, this was just what jumped out at me.)

frostedoyster commented 6 months ago

Indeed, that was the issue! Thanks a lot for the reply and for the library!