Status: Closed (aidancrilly closed this issue 6 months ago)
I have found an issue that occurs with JAX 0.4.27 but not 0.4.26 when using diffrax to solve a neural ODE. I am using the following package versions: diffrax 0.5.0, equinox 0.11.4.
Here is an example where the error occurs:
```python
import diffrax
import equinox as eqx
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jrandom


class Func(eqx.Module):
    # Vector field of the ODE: dy/dt = mlp(y)
    mlp: eqx.nn.MLP

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)


class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, in_size, out_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(in_size, out_size, width_size, depth, key=key)

    def __call__(self, ts, y0, args):
        # Adaptive Heun solve, saving the solution at every time in ts
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Heun(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys


times = jnp.linspace(0.0, 1.0, 100)
args = {}
seed = 0
key = jrandom.PRNGKey(seed)
__, model_key = jrandom.split(key)
model = NeuralODE(3, 3, 32, 3, key=model_key)
result = model(times, jnp.zeros(3), args)  # raises ValueError on JAX 0.4.27
```
This runs without error on JAX 0.4.26, but on 0.4.27 it fails with the following error:
```
  File "f:\Anaconda3\envs\HydroSurrogate\Lib\site-packages\equinox\internal\_loop\checkpointed.py", line 268, in <lambda>
    _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
                          ^^^^^^^^^^^
ValueError: safe_map() argument 2 is shorter than argument 1
```
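For readers puzzled by the message itself: `safe_map` is a helper inside JAX that behaves like Python's `map` but refuses to silently truncate when its input sequences have different lengths. The pure-Python sketch below is our own illustration of that behavior, not JAX's actual implementation. It shows that the error simply means two sequences that were expected to line up one-to-one did not, which is consistent with the maintainer's diagnosis of an internal JAX bug rather than a mistake in the user's model code.

```python
def safe_map(f, *args):
    """Illustrative length-checked map (a sketch, not JAX's code).

    Like the built-in map, but raises instead of stopping at the
    shortest input when the argument sequences differ in length.
    """
    args = [list(a) for a in args]
    n = len(args[0])
    # Argument numbering starts at 1 to match the error message format.
    for i, a in enumerate(args[1:], start=2):
        if len(a) != n:
            rel = "shorter" if len(a) < n else "longer"
            raise ValueError(f"safe_map() argument {i} is {rel} than argument 1")
    return list(map(f, *args))


# Equal lengths: behaves like map.
print(safe_map(lambda a, b: a + b, [1, 2], [3, 4]))

# Mismatched lengths: reproduces the shape of the error in the traceback.
try:
    safe_map(lambda a, b: a + b, [1, 2], [3])
except ValueError as e:
    print(e)
```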
Thanks for the report! This is an upstream JAX bug in google/jax#21116. It should be fixed soon in JAX 0.4.28.
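Until the fix ships, the thread points to two options: stay on JAX 0.4.26, or move to 0.4.28 once it is released. If a project cannot control which JAX version users install, one hedged option is a runtime guard that warns on the affected release. The helper names below (`release_tuple`, `is_affected_jax`, `warn_if_affected_jax`) are hypothetical and not part of JAX, diffrax, or equinox; the only JAX API assumed is `jax.__version__`, used in the commented usage line.

```python
import re
import warnings

# Hypothetical guard: 0.4.27 is the only release reported as affected here
# (0.4.26 works; the fix is expected in 0.4.28).
AFFECTED_JAX_VERSIONS = {(0, 4, 27)}


def release_tuple(version: str) -> tuple:
    """Parse the leading numeric part of a version string.

    "0.4.27" -> (0, 4, 27); "0.4.28.dev20240501" -> (0, 4, 28).
    """
    m = re.match(r"\d+(?:\.\d+)*", version)
    if m is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(p) for p in m.group(0).split("."))


def is_affected_jax(version: str) -> bool:
    """Return True if the version is one reported as affected in this issue."""
    return release_tuple(version)[:3] in AFFECTED_JAX_VERSIONS


def warn_if_affected_jax(version: str) -> None:
    """Emit a warning for the affected JAX release; do nothing otherwise."""
    if is_affected_jax(version):
        warnings.warn(
            f"JAX {version} has a known bug (google/jax#21116) that breaks "
            "diffrax solves; use JAX 0.4.26 or upgrade to 0.4.28 or later.",
            RuntimeWarning,
        )


# Usage (requires JAX to be installed):
#   import jax
#   warn_if_affected_jax(jax.__version__)
```

Warning rather than raising keeps the guard harmless if the affected range turns out to be wider or narrower than reported; adjust `AFFECTED_JAX_VERSIONS` if further releases are found to be affected.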
Great! I shall close this issue.