jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JaxStackTraceBeforeTransformation error with parametrized ODE #24253

Open SnowOwl-Hedwig opened 1 week ago

SnowOwl-Hedwig commented 1 week ago

Description

Hi everyone,

Based on this tutorial, I tried to get started with JAX and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh

However, I get a JaxStackTraceBeforeTransformation error (full error message below). I boiled the code down to a minimal working example (also below) and noticed that the error only occurs when the right-hand side in test_func depends on an argument (here _coeff, passed in via args=coeff). Since this issue looked similar to an earlier one (https://github.com/jax-ml/jax/issues/13629), I tried downgrading jax to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)

Minimal working example:

from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
import jax
import jax.numpy as jnp

def f(t, y, _):
  dp_dt = 0.9 * y
  return dp_dt

b0 = 2  # init condition
data_ts = jnp.linspace(0, 20, 100)
data_sol = diffeqsolve(ODETerm(f), Tsit5(), t0=0, t1=20, dt0=0.01,
                       y0=(b0), saveat=SaveAt(ts=data_ts))

def fwd_test(coeff):
    num_ts = 100
    def test_func(t, y, _coeff):
        dp_dt = y * _coeff #doesn't work
        # dp_dt = y #works
        return dp_dt

    b0 = 2
    model_ts = jnp.linspace(0, 20, num_ts)
    # Note: larger dt0 so that it runs faster; this is about as large as it can go
    model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
                        y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
    model_b = model_sol.ys
    data_b = data_sol.ys
    return jnp.sum((model_b - data_b)**2)

coeff = 1.
grads = jax.grad(fwd_test)(coeff)
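As a sanity check that bypasses diffrax and equinox entirely, the same kind of loss can be written with a hand-rolled fixed-step integrator in pure JAX. This is a minimal sketch under stated assumptions: it uses explicit Euler in place of Tsit5 (so the numbers will not match diffrax's), uses a floating-point initial condition, and euler_solve and loss are hypothetical helper names, not diffrax API. If jax.grad works here, that is consistent with the problem lying in the solver libraries rather than in differentiating through the parameter itself.

```python
import jax
import jax.numpy as jnp


def euler_solve(coeff, y0, ts, substeps=50):
    """Integrate dy/dt = coeff * y with fixed-step explicit Euler (hypothetical helper).

    Returns y at every time in ts; y at ts[0] is the initial condition.
    """
    y0 = jnp.asarray(y0, dtype=ts.dtype)  # match the float dtype of ts

    def between_saves(y, t_pair):
        # Advance y from t_start to t_end in `substeps` equal Euler steps.
        t_start, t_end = t_pair
        dt = (t_end - t_start) / substeps

        def euler_step(y, _):
            return y + dt * coeff * y, None

        y, _ = jax.lax.scan(euler_step, y, None, length=substeps)
        return y, y  # carry the state forward and also record it

    _, ys = jax.lax.scan(between_saves, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])  # prepend the value at ts[0]


ts = jnp.linspace(0.0, 20.0, 100)
y0 = 2.0  # floating-point initial condition (assumption; the MWE uses the int 2)
data_ys = euler_solve(0.9, y0, ts)  # synthetic "data" generated with coeff = 0.9


def loss(coeff):
    # Squared error between the model trajectory and the synthetic data.
    model_ys = euler_solve(coeff, y0, ts)
    return jnp.sum((model_ys - data_ys) ** 2)


grads = jax.grad(loss)(1.0)
print(grads)
```

Using jax.lax.scan rather than a Python loop keeps the traced program small, and the gradient with respect to coeff flows through the closure in euler_step just as it would through args in diffrax.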

Error message:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\traitlets\config\application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelapp.py:739, in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\tornado\platform\asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:607, in run_forever()
    606 while True:
--> 607     self._run_once()
    608     if self._stopping:

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:1919, in _run_once()
   1918     else:
-> 1919         handle._run()
   1920 handle = None

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:545, in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:534, in process_one()
    533         return
--> 534 await dispatch(*args)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:437, in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:362, in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:778, in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:449, in do_execute()
    448 if accepts_params["cell_id"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3075, in run_cell()
   3074 try:
-> 3075     result = self._run_cell(
   3076         raw_cell, store_history, silent, shell_futures, cell_id
   3077     )
   3078 finally:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3130, in _run_cell()
   3129 try:
-> 3130     result = runner(coro)
   3131 except BaseException as e:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3334, in run_cell_async()
   3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3335        interactivity=interactivity, compiler=compiler, result=result)
   3337 self.last_execution_succeeded = not has_raised

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3517, in run_ast_nodes()
   3516     asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
   3518     return True

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3577, in run_code()
   3576     else:
-> 3577         exec(code_obj, self.user_global_ns, self.user_ns)
   3578 finally:
   3579     # Reset our crash handler in place

Cell In[1], line 32
     31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
     33 # print(grads)

Cell In[1], line 24, in fwd_test()
     23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
     25                     y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
     26 model_b = model_sol.ys

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:823, in diffeqsolve()
    819 #
    820 # Main loop
    821 #
--> 823 final_state, aux_stats = adjoint.loop(
    824     args=args,
    825     terms=terms,
    826     solver=solver,
    827     stepsize_controller=stepsize_controller,
    828     discrete_terminating_event=discrete_terminating_event,
    829     saveat=saveat,
    830     t0=t0,
    831     t1=t1,
    832     dt0=dt0,
    833     max_steps=max_steps,
    834     init_state=init_state,
    835     throw=throw,
    836     passed_solver_state=passed_solver_state,
    837     passed_controller_state=passed_controller_state,
    838 )
    840 #
    841 # Finish up
    842 #

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\adjoint.py:286, in loop()
    285     msg = None
--> 286 final_state = self._loop(
    287     terms=terms,
    288     saveat=saveat,
    289     init_state=init_state,
    290     max_steps=max_steps,
    291     inner_while_loop=inner_while_loop,
    292     outer_while_loop=outer_while_loop,
    293     **kwargs,
    294 )
    295 if msg is not None:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:429, in loop()
    427 del filter_state
--> 429 final_state = outer_while_loop(
    430     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    431 )
    433 def _save_t1(subsaveat, save_state):

File ~\AppData\Local\Programs\Python\Python311\Lib\contextlib.py:81, in inner()
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:247, in checkpointed_while_loop()
    246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\common.py:463, in new_body_fun()
    462 buffer_val = _wrap_buffers(val, pred, tag)
--> 463 buffer_val2 = body_fun(buffer_val)
    464 # Needed to work with `disable_jit`, as then we lose the automatic
    465 # ArrayLike->Array cast provided by JAX's while loops.
    466 # The input `val` is already cast to Array below, so this matches that.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:219, in body_fun()
    214 #
    215 # Actually do some differential equation solving! Make numerical steps, adapt
    216 # step sizes, all that jazz.
    217 #
--> 219 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    220     terms,
    221     state.tprev,
    222     state.tnext,
    223     state.y,
    224     args,
    225     state.solver_state,
    226     state.made_jump,
    227 )
    229 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
    230 # we get a negative value for y, and then get a NaN vector field. (And then
    231 # everything breaks.) See #143.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\solver\runge_kutta.py:1041, in step()
   1035 # Needs to be an `eqxi.while_loop` as:
   1036 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
   1037 #     more stage on the first step.
   1038 # (b) to work around a limitation of JAX's autodiff being unable to express
   1039 #     "triangular computations" (every stage depends on all previous stages)
   1040 #     without spurious copies.
-> 1041 final_val = eqxi.while_loop(
   1042     cond_stage,
   1043     rk_stage,
   1044     init_val,
   1045     max_steps=num_stages,
   1046     buffers=buffers,
   1047     kind="checkpointed" if self.scan_kind is None else self.scan_kind,
   1048     checkpoints=num_stages,
   1049     base=num_stages,
   1050 )
   1051 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:252, in checkpointed_while_loop()
    249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )
--> 252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
    253 return final_val

JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 32
     28     return jnp.sum((model_b - data_b)**2)
     31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
     33 # print(grads)

    [... skipping hidden 10 frame]

Cell In[1], line 24, in fwd_test(coeff)
     22 model_ts = jnp.linspace(0, 20, num_ts)
     23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
     25                     y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
     26 model_b = model_sol.ys
     27 data_b = data_sol.ys

    [... skipping hidden 27 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
   1268 del primals, tangents
   1269 perturb_val, perturb_body_fun = jtu.tree_map(
   1270     lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
   1271 )
-> 1272 perturb_val = _resolve_perturb_val(
   1273     init_val, body_fun, perturb_val, perturb_body_fun
   1274 )
   1275 t_final_val = jtu.tree_map(
   1276     _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
   1277 )
   1278 return final_val, t_final_val

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
   1238         else:
   1239             perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
   1242 return perturb_val

    [... skipping hidden 12 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
   1211     return _out
   1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
   1215 if new_perturb_val is sentinel:
   1216     # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
   1217     # `_dynamic` having them. Presumably the user's `_body_fun` has no
   1218     # differentiable dependency whatsoever.
   1219     # This can happen if all the autograd is happening through
   1220     # `perturb_body_fun`.
   1221     return Static(perturb_val)

    [... skipping hidden 5 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
   1205 def _to_linearize(_dynamic):
   1206     _body_fun, _val = combine(_dynamic, static)
-> 1207     _out = _body_fun(_val)
   1208     _dynamic_out, _static_out = partition(_out, is_inexact_array)
   1209     _dynamic_out = _record_symbolic_zeros(_dynamic_out)

    [... skipping hidden 10 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    344     msg = ("Custom JVP rule must produce primal and tangent outputs with "
    345            "corresponding shapes and dtypes, but got:\n{}")
    346     disagreements = (
    347         f"  primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
    348         for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
    349         if av_et != av_t)
--> 351     raise TypeError(msg.format('\n'.join(disagreements)))
    352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.34 jaxlib: 0.4.34 numpy: 1.26.4 python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)] jax.devices (1 total, 1 local): [CpuDevice(id=0)] process_count: 1 platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')

jupyterlab: 4.2.2 diffrax: 0.4.1
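The info above lists jax, jaxlib and diffrax but not equinox, which also appears throughout the traceback. The jax line appears to come from jax.print_environment_info(), which does not cover the downstream libraries. A small sketch like the one below collects all four versions at once, which can help when filing a follow-up report; library_versions is a hypothetical helper built only on the standard-library importlib.metadata module.

```python
from importlib.metadata import PackageNotFoundError, version


def library_versions(packages=("jax", "jaxlib", "equinox", "diffrax")):
    """Return {package: version string}, using 'not installed' for missing packages."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)  # reads the installed distribution's metadata
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


for name, ver in library_versions().items():
    print(f"{name}: {ver}")
```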

dfm commented 1 week ago

The error reported here is actually a TypeError, raised because a jax.custom_jvp rule returns tangents whose dtypes don't match its outputs: for the int32 and bool outputs JAX expects float0 tangents, but the rule returned int32 and bool tangents instead. It's hard to tell from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the https://github.com/patrick-kidger/diffrax issue tracker.

SnowOwl-Hedwig commented 1 week ago

ok, thanks for pointing this out. I'll try my luck there.