patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.45k stars 130 forks source link

jax.experimental.enable_x64 and jit #523

Open dv-ai opened 1 week ago

dv-ai commented 1 week ago

diffrax 0.6.0 jax 0.4.30

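Using the jax.experimental.enable_x64 context manager inside a function that is then wrapped in jax.jit raises an exception (the same function works when called without jit).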

import jax
import jax.numpy as jnp
import diffrax

# Fixed seed so the reproduction is deterministic. Note that both draws below
# deliberately reuse the same key, exactly as in the original report.
key = jax.random.PRNGKey(0)
x1 = jax.random.uniform(key, shape=(2,))  # later passed as y0 to ode_fun
x2 = jax.random.uniform(key, shape=(2, 2))  # later passed as w to ode_fun

def odes(t, y, w):
    """Return the right-hand side of the linear ODE dy/dt = w @ y.

    Args:
        t: Current time. Accepted for the (t, y, args) call signature but
            not used by this vector field.
        y: Current state.
        w: Matrix applied to the state via the ``@`` operator.

    Returns:
        The product ``w @ y``.
    """
    return w @ y

# Define single solve function for fixed final time with scaling
def ode_fun(y0, w):
    """Solve the ODE defined by ``odes`` from t0=0.0 to t1=1.0 in float64.

    The whole solve runs inside the ``jax.experimental.enable_x64()`` context
    manager: the inputs are cast to float64 on entry and the saved solution
    is cast back to float32 before leaving the context.

    Args:
        y0: Initial state; cast to float64 before being passed to the solver.
        w: Forwarded to ``odes`` through ``args``; cast to float64 first.

    Returns:
        ``sol.ys`` from ``diffrax.diffeqsolve``, cast to float32.

    NOTE(review): this is the reproduction from the report. Calling it
    directly works, while the ``jax.jit``-wrapped version raises
    ``ValueError: Operation creation failed``.
    """
    # Enable 64-bit types only for the duration of this block.
    with jax.experimental.enable_x64():
        y0 = y0.astype(jnp.float64)
        w = w.astype(jnp.float64)

        # Wrap the vector field and pick the Dopri5 solver.
        term = diffrax.ODETerm(odes)
        solver = diffrax.Dopri5()

        # Adaptive step sizing with tight tolerances (atol = rtol = 1e-8).
        controler = diffrax.PIDController(atol=1E-8,rtol=1E-8)
        # Request the solution at the final time t1.
        saveat = diffrax.SaveAt(t1=True)
        sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w)  # Fixed t1=1.0
        result = sol.ys
        # Cast back to float32 while still inside the x64 context.
        result = result.astype(jnp.float32)

    return result

print(ode_fun(x1,x2)) # working: direct (non-jitted) call
ode_fun_jit = jax.jit(ode_fun)
print(ode_fun_jit(x1, x2)) # exception see below: the jitted call raises ValueError

Exception raised:

Traceback (most recent call last):
  File "tmp.py", line 122, in <module>
    print(ode_fun(x1,x2))
  File "tmp.py", line 116, in ode_fun
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, y0=y0, saveat=saveat, stepsize_controller=controler, dt0=None, max_steps=None, args=w)  # Fixed t1=1.0
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 1337, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_adjoint.py", line 292, in loop
    final_state = self._loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 621, in loop
    final_state = outer_while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 103, in while_loop
    _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 618, in body_fun
    new_state, _, _ = body_fun_aux(state)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_integrate.py", line 332, in body_fun_aux
    (y, y_error, dense_info, solver_state, solver_result) = solver.step(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 1151, in step
    final_val = eqxi.while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\loop.py", line 107, in while_loop
    return checkpointed_while_loop(
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\checkpointed.py", line 247, in checkpointed_while_loop
    body_fun_ = filter_closure_convert(body_fun_, init_val_)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\equinox\internal\_loop\common.py", line 463, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in rk_stage
    a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 604, in t_map
    return jtu.tree_map(_fn, tableaus, *trees)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 602, in _fn
    return fn(*_trees)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\diffrax\_solver\runge_kutta.py", line 855, in <lambda>
    a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 739, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\array_methods.py", line 352, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6579, in _rewriting_take
    if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6563, in _attempt_rewriting_take_via_slice
    arr = lax.dynamic_slice(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Operation creation failed

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tmp.py", line 125, in <module>
    print(ode_fun_jit(x1, x2))
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2487, in dynamic_slice
    return _get_op_result_or_op_results(DynamicSliceOp(operand=operand, start_indices=start_indices, slice_sizes=slice_sizes, loc=loc, ip=ip))
  File "\AppData\Local\Continuum\anaconda3\envs\jax_env\lib\site-packages\jaxlib\mlir\dialects\_stablehlo_ops_gen.py", line 2461, in __init__
    super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
ValueError: Operation creation failed
patrick-kidger commented 1 week ago

AFAIK jax.experimental.enable_x64 is still fairly buggy/experimental, in part because of things like this. I don't think it's recommended for use.