patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

64-bit mode Error #67

Closed: adam-hartshorne closed this issue 2 years ago

adam-hartshorne commented 2 years ago

Using Python 3.9, jax 0.2.27, jaxlib 0.1.75 and equinox 0.1.5 on a Windows 10 box, running any of the examples that put JAX into 64-bit mode (e.g. the lotka_volterra benchmark) produces the following error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
    ref_sol = diffeqsolve(
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\equinox\jit.py", line 90, in fun_wrapper
    dynamic_out, static_out = _filter_jit_cache(fun, **jitkwargs)(
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\equinox\jit.py", line 25, in f_wrapped
    out = f(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\integrate.py", line 808, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\adjoint.py", line 76, in loop
    return self._loop_fn(**kwargs, is_bounded=True)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\integrate.py", line 444, in loop
    final_state = bounded_while_loop(
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 135, in bounded_while_loop
    _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 245, in _while_loop
    return lax.scan(_scan_fn, data, xs=None, length=base)[0]
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 239, in _scan_fn
    return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 234, in _call
    return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 245, in _while_loop
    return lax.scan(_scan_fn, data, xs=None, length=base)[0]
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 239, in _scan_fn
    return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 234, in _call
    return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 245, in _while_loop
    return lax.scan(_scan_fn, data, xs=None, length=base)[0]
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 239, in _scan_fn
    return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 234, in _call
    return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 223, in _while_loop
    new_val = jax.tree_map(
  File "C:\Users\Adam\anaconda3\envs\jax39\lib\site-packages\diffrax\misc\bounded_while_loop.py", line 221, in _make_update
    return lax.select(pred, _new_val, _val)
TypeError: lax.select requires arguments to have the same dtypes, got int32, int64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

Process finished with exit code 1
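The final `TypeError` can be reproduced in isolation. This is a minimal sketch, assuming a recent JAX (the exact wording of the error differs between versions, and newer releases raise it from `select_n`). The int32/int64 pair is constructed by hand here; it is not a claim about how the mismatch actually arises inside diffrax:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Put JAX into 64-bit mode, as the failing examples do.
jax.config.update("jax_enable_x64", True)

pred = jnp.array(True)
a = jnp.array(1, dtype=jnp.int32)  # a value that ended up as int32
b = jnp.array(2, dtype=jnp.int64)  # a value that ended up as int64

# lax.select performs no type promotion, so mixed dtypes are rejected.
try:
    lax.select(pred, a, b)
    raised = False
except TypeError as e:
    raised = True
    print("TypeError:", e)

# jnp.where promotes both branches to a common dtype (int64 here).
w = jnp.where(pred, a, b)
print(w, w.dtype)

# Alternatively, cast explicitly so lax.select sees matching dtypes.
s = lax.select(pred, a.astype(b.dtype), b)
print(s, s.dtype)
```

`jnp.where` is the promoting alternative the error message itself suggests; the explicit `astype` keeps `lax.select` but makes the two dtypes agree.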
patrick-kidger commented 2 years ago

Interesting. I assume this is OS-dependent behaviour, as I don't get the same error on Linux (running under WSL2).

How are you running JAX on Windows? Have you built jaxlib from source yourself?

adam-hartshorne commented 2 years ago

I have tried both a jaxlib I built myself and the one from https://github.com/cloudhan/jax-windows-builder.

I should also note that I am currently using JAX in x64 mode with its built-in ODE solver (`jax.experimental.ode.odeint`) with no issues.
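For comparison, a minimal sketch of running JAX's own solver, `jax.experimental.ode.odeint`, in x64 mode on a Lotka-Volterra system. The parameter values and initial condition are arbitrary choices for illustration, not those of the diffrax benchmark:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)


def lotka_volterra(y, t, alpha, beta, gamma, delta):
    # odeint calls func(y, t, *args); y = [prey, predator].
    prey, predator = y
    d_prey = alpha * prey - beta * prey * predator
    d_predator = -gamma * predator + delta * prey * predator
    return jnp.stack([d_prey, d_predator])


y0 = jnp.array([10.0, 10.0])
ts = jnp.linspace(0.0, 10.0, 101)
# Extra positional arguments after `ts` are forwarded to lotka_volterra.
ys = odeint(lotka_volterra, y0, ts, 0.1, 0.02, 0.4, 0.02)
print(ys.shape, ys.dtype)
```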

I vaguely recall hitting a similar problem with JAX itself a while back, and it may well also have been an OS-specific bug.

patrick-kidger commented 2 years ago

Right, I've tracked down the root cause: https://github.com/google/jax/issues/9574

This is something that can be worked around on our end; I'll include a fix in the upcoming v0.0.3 release.
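One way such a workaround could look, as a hypothetical sketch and not the actual v0.0.3 change: cast each new value to the dtype of the value it replaces before calling `lax.select`, leaf by leaf over the loop state, mirroring the `jax.tree_map`/`lax.select` pattern in `_make_update` from the traceback. `select_preserving_dtype` and `masked_update` are illustrative names only:

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def select_preserving_dtype(pred, new_val, val):
    """Hypothetical helper: like lax.select, but first casts `new_val`
    to the dtype of `val`, so the result always has val's dtype."""
    val = jnp.asarray(val)
    new_val = jnp.asarray(new_val).astype(val.dtype)
    return lax.select(pred, new_val, val)


def masked_update(pred, new_tree, old_tree):
    # Apply the dtype-preserving select to each pair of matching leaves.
    return jax.tree_util.tree_map(
        lambda n, o: select_preserving_dtype(pred, n, o), new_tree, old_tree
    )


old = {"step": jnp.array(0, dtype=jnp.int64), "y": jnp.zeros(2)}
new = {"step": jnp.array(1, dtype=jnp.int32), "y": jnp.ones(2)}

out = masked_update(jnp.array(True), new, old)
print(out["step"].dtype, out["y"])
```

Casting to the existing state's dtype keeps the loop carry's type fixed across iterations, which `lax.scan` and both branches of `lax.cond` require. The trade-off is a silent downcast whenever the new value is wider than the old one.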

adam-hartshorne commented 2 years ago

Thanks for your fast response.

patrick-kidger commented 2 years ago

Version 0.0.3, which should fix this, is now available on PyPI.