patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Value error when used in gradient optimization with equinox latest version #502

Closed: ParticularlyPythonicBS closed this 1 month ago

ParticularlyPythonicBS commented 1 month ago

Using diffrax ODE integration within an equinox NN training loop throws the error:

ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided: ...

traceback.txt

The stderr is attached since it's too long to paste into the issue.

This happens on equinox 0.11.6, jax 0.4.33, and diffrax 0.6.0, while the same code works perfectly fine on equinox 0.10.6, jax 0.4.13, and diffrax 0.4.0.

Here is a minimal example (MVE) that reproduces the attached traceback:

import diffrax
import jax.numpy as jnp
import jax
import equinox as eqx
import optax

def odeint(dynamics):
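    # Build and return an integrator that solves `dynamics` and saves the
    # solution at the requested times `ts`.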
    def integrator(args, ts, y0):
        terms = diffrax.ODETerm(dynamics)
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = ts[1] - ts[0]
        saveat = diffrax.SaveAt(ts=ts)
        sol = diffrax.diffeqsolve(
            terms,
            diffrax.Tsit5(),
            t0, t1, dt0,
            y0,
            args=args,
            saveat=saveat,
        )
        return sol.ys
    return integrator

def dynamics(t, y, args):
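    # dy/dt = f(t) * y, where f is the neural network passed via `args`.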
    dy = args(t) * y
    return dy

class NN(eqx.Module):
    layer: eqx.nn.Linear

    def __init__(self, key):
        self.layer = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, t):
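        # diffrax passes t as a scalar; reshape to (1,) for the Linear layer.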
        t = jnp.array(t).flatten()
        u = self.layer(t)
        return u

def compute_loss(args, integrator, x, ts):
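    # Loss: mean of the final solver state.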
    loss = integrator(args, ts, x)[-1].mean()
    return loss

def make_step(controller, integrator, x, ts, optim, opt_state):
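    # On the newer versions, the ValueError is raised inside this filter_grad call.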
    grads = eqx.filter_grad(compute_loss)(controller, integrator, x, ts)
    updates, opt_state = optim.update(grads, opt_state)
    controller = eqx.apply_updates(controller, updates)
    return controller, opt_state

ts = jnp.arange(0.0, 1, 0.1)
y0 = jnp.array([1])

model_key = jax.random.PRNGKey(0)
neural_net = NN(model_key)
integrator = odeint(dynamics)

optimizer = optax.sgd(learning_rate=3e-3)

opt_state = optimizer.init(eqx.filter(neural_net, eqx.is_inexact_array))

neural_net, opt_state = make_step(
    neural_net,
    integrator,
    y0,
    ts,
    optimizer,
    opt_state,
)

As far as I'm aware, none of this code has triggered any deprecation warnings.

Is there a better way to structure this, i.e. to have an equinox neural network supply the `args` for the function being integrated with diffrax?
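
One alternative I've considered (a sketch along the lines of the diffrax neural-ODE examples; names like VectorField and solve are just mine, and I haven't verified it avoids the error above) is to make the vector field itself an eqx.Module, so the network lives inside the ODETerm rather than being threaded through args:

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp

class VectorField(eqx.Module):
    layer: eqx.nn.Linear

    def __init__(self, key):
        self.layer = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, t, y, args):
        # Same dynamics as above: dy/dt = f(t) * y.
        return self.layer(jnp.array(t).reshape(1)) * y

def solve(field, ts, y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(field),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=ts[1] - ts[0],
        y0=y0,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return sol.ys

ts = jnp.arange(0.0, 1.0, 0.1)
loss = lambda field: solve(field, ts, jnp.array([1.0]))[-1].mean()
grads = eqx.filter_grad(loss)(VectorField(jax.random.PRNGKey(0)))

Gradients are then taken with respect to the module directly, which is the pattern the diffrax documentation uses for neural ODEs.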

lockwo commented 1 month ago

This looks like the same issue in equinox that came from the new weak_type field on structs in jax 0.4.33 (see https://github.com/patrick-kidger/equinox/issues/854 and https://github.com/google/jax/issues/23690).
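
For context, "weak type" is the flag JAX attaches to values created from bare Python scalars. A minimal illustration for anyone unfamiliar (just a sketch, not from the linked issues):

import jax.numpy as jnp

a = jnp.array(1.0)    # from a Python scalar -> weakly typed float32
b = jnp.array([1.0])  # from a concrete array -> strongly typed float32
print(a.dtype, b.dtype)          # float32 float32
print(a.weak_type, b.weak_type)  # True False

As I understand the linked issues, jax 0.4.33 started recording this flag on its shape/dtype structs, so structures that previously compared equal in equinox's closure conversion can now mismatch.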

With diffrax 0.6.0, equinox 0.11.6, and jax 0.4.31, it works.
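
For anyone pinning versions, a quick sanity check that the environment matches that known-good combination:

import diffrax, equinox, jax
print(diffrax.__version__, equinox.__version__, jax.__version__)
# expected: 0.6.0 0.11.6 0.4.31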

ParticularlyPythonicBS commented 1 month ago

Thank you so much; freezing the jax version fixes it for now. I hope the upstream issue is fixed soon.

patrick-kidger commented 1 month ago

Closing as fixed in Equinox v0.11.7 / https://github.com/patrick-kidger/equinox/pull/856! Thanks for the report :)

ParticularlyPythonicBS commented 1 month ago

Closing, since this was fixed and the previous comment was intended to close it.