Using jax.experimental.enable_x64 and jit will return a exception.
import jax
import jax.numpy as jnp
import diffrax
key = jax.random.PRNGKey(0)
x1 = jax.random.uniform(key,(2,))
x2 = jax.random.uniform(key,(2,2))
def odes(t, y, w):
result = w @ y
return result
# Define single solve function for fixed final time with scaling
def ode_fun(y0, w):
with jax.experimental.enable_x64():
y0 = y0.astype(jnp.float64)
w = w.astype(jnp.float64)
term = diffrax.ODETerm(odes)
solver = diffrax.Dopri5()
controler = diffrax.PIDController(atol=1E-8,rtol=1E-8)
saveat = diffrax.SaveAt(t1=True)
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w) # Fixed t1=1.0
result = sol.ys
result = result.astype(jnp.float32)
return result
print(ode_fun(x1,x2)) # working
ode_fun_jit = jax.jit(ode_fun)
print(ode_fun_jit(x1, x2)) # exception see below
Exception raised:
raceback (most recent call last):
File "tmp.py", line 122, in <module>
print(ode_fun(x1,x2))
File "tmp.py", line 116, in ode_fun
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w) # Fixed t1=1.0
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 1337, in diffeqsolve
final_state, aux_stats = adjoint.loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_adjoint.py", line 292, in loop
final_state = self._loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 621, in loop
final_state = outer_while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\contextlib.py", line 79, in inner
return func(*args, **kwds)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 103, in while_loop
_, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 618, in body_fun
new_state, _, _ = body_fun_aux(state)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 332, in body_fun_aux
(y, y_error, dense_info, solver_state, solver_result) = solver.step(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 1151, in step
final_val = eqxi.while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 107, in while_loop
return checkpointed_while_loop(
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\checkpointed.py", line 247, in checkpointed_while_loop
body_fun_ = filter_closure_convert(body_fun_, init_val_)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in rk_stage
a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 604, in t_map
return jtu.tree_map(_fn, tableaus, *trees)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 602, in _fn
return fn(*_trees)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in <lambda>
a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 739, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 352, in _getitem
return lax_numpy._rewriting_take(self, item)
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6579, in _rewriting_take
if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6563, in _attempt_rewriting_take_via_slice
arr = lax.dynamic_slice(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Operation creation failed
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "tmp.py", line 125, in <module>
print(ode_fun_jit(x1, x2))
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2487, in dynamic_slice
return _get_op_result_or_op_results(DynamicSliceOp(operand=operand, start_indices=start_indices, slice_sizes=slice_sizes, loc=loc, ip=ip))
File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2461, in __init__
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
ValueError: Operation creation failed
diffrax 0.6.0 jax 0.4.30
Using jax.experimental.enable_x64 and jit will return a exception.
Exception raised: