patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Jax 0.4.27 safe_map() error for equinox+diffrax #415

Closed aidancrilly closed 6 months ago

aidancrilly commented 6 months ago

I have found an issue that occurs with JAX 0.4.27 but not 0.4.26 when using diffrax to solve a neural ODE. I am using diffrax 0.5.0 and equinox 0.11.4.

Here is an example where the error occurs:

import diffrax
import equinox as eqx
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jrandom

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(in_size, out_size, width_size, depth, key=key)

    def __call__(self, ts, y0, args):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Heun(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts)
        )
        return solution.ys

times = jnp.linspace(0.0, 1.0, 100)
args = {}

seed = 0
key = jrandom.PRNGKey(seed)
_, model_key = jrandom.split(key)
model = NeuralODE(3, 3, 32, 3, key=model_key)
result = model(times, jnp.zeros(3), args)

This runs without error on JAX 0.4.26, but the following error occurs on JAX 0.4.27:

  File "f:\Anaconda3\envs\HydroSurrogate\Lib\site-packages\equinox\internal\_loop\checkpointed.py", line 268, in <lambda>
    _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
                          ^^^^^^^^^^^
ValueError: safe_map() argument 2 is shorter than argument 1

patrick-kidger commented 6 months ago

Thanks for the report! This is an upstream JAX bug in google/jax#21116. It should be fixed soon in JAX 0.4.28.
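
Until the fix is released, a guard along these lines could flag the affected release early. It is only a sketch: the helper names `parse_release`, `is_affected`, and `check_jax_version` are hypothetical, and it assumes 0.4.27 is the only affected version, which is what this thread suggests but does not confirm. The check takes the version string as an argument, so in practice one would pass `jax.__version__`.

```python
import re

# Assumption: only this release is affected (0.4.26 works, fix expected in 0.4.28).
AFFECTED_RELEASE = (0, 4, 27)


def parse_release(version: str) -> tuple:
    """Return the leading numeric components of a version string, at most three."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def is_affected(version: str) -> bool:
    """True if the version string names the release assumed to be affected."""
    return parse_release(version) == AFFECTED_RELEASE


def check_jax_version(version: str) -> None:
    """Raise early with a readable message instead of failing inside the solver."""
    if is_affected(version):
        raise RuntimeError(
            f"JAX {version} appears to hit the safe_map() error described in this "
            "issue; consider using 0.4.26, or 0.4.28 once it is available."
        )


# Typical use (requires JAX to be installed):
#     import jax
#     check_jax_version(jax.__version__)
```

Calling the guard before building the model turns the opaque `safe_map()` traceback into an actionable message.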

aidancrilly commented 6 months ago

Great! I shall close this issue.