patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

TypeError #10

Closed pharringtonp19 closed 1 year ago

pharringtonp19 commented 1 year ago

Excited to explore the library as always!

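from typing import NamedTuple

import jax
import jax.numpy as jnp
import optimistix

class MiniData(NamedTuple):
    X: jax.Array  # public array type, used here in place of the unimported ArrayImpl
    Y: jax.Array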

def loss_fn_per_obs(y, p):
    return jnp.where(y == 1.0, -jnp.log(p), -jnp.log(1 - p))

def fn(params, args):
    P = jax.nn.sigmoid(args.X @ params)
    losses = jax.vmap(loss_fn_per_obs, in_axes=(0,0))(args.Y, P)
    return jnp.mean(losses)

init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(19, 1))
data = MiniData(X=jax.random.normal(jax.random.PRNGKey(1), shape=(100, 19)),
                Y=jax.random.normal(jax.random.PRNGKey(2), shape=(100, 1)))
solver = optimistix.NonlinearCG(rtol=0.01, atol=0.01)
optimistix.minimise(fn=fn, solver=solver, y0=init_params, args=data, has_aux=False)

I am running into the following type error:

TypeError: linearize() got an unexpected keyword argument 'has_aux'

patrick-kidger commented 1 year ago

Which version of JAX are you using? This was added in one of the more recent JAX releases. (It's possible we need to bump the minimum version required by Optimistix.)

pharringtonp19 commented 1 year ago

@patrick-kidger Everything worked once I updated the necessary libraries. My bad!
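
For anyone who hits the same error, here is a rough sketch of how one might check the environment before upgrading. It prints the installed versions of the relevant packages and tests whether the installed `jax.linearize` accepts a `has_aux` keyword. The helper names `accepts_keyword` and `installed_version` are made up for this illustration and are not part of JAX or Optimistix. The package list is also an assumption about what is worth checking.

```python
import inspect
from importlib import metadata


def accepts_keyword(func, name):
    """Return True if `func` can be called with keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Hypothetical list of packages to inspect; adjust to your environment.
    for package in ("jax", "jaxlib", "equinox", "optimistix"):
        print(package, installed_version(package))
    try:
        import jax
    except ImportError:
        print("jax is not installed")
    else:
        print("linearize accepts has_aux:",
              accepts_keyword(jax.linearize, "has_aux"))
```

If the last line reports `False`, the installed JAX is older than what Optimistix expects, and upgrading JAX (together with jaxlib) is the likely fix, as it was in this thread.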