The `jax.disable_jit()` context manager disables jit for `jax.lax.scan`, but it does not disable jit for `bm.scan`. This makes it hard to debug functions inside `bm.scan`. The following code reproduces the problem:
```python
import jax
import jax.numpy as jnp
import brainpy.math as bm


def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    res = res + el
    print(res)
    return res, res  # ("carryover", "accumulated")


a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = jax.lax.scan(cumsum, result_init, a)

b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = bm.scan(cumsum, result_init, b)
```
The printed output is:
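For comparison, here is a minimal JAX-only sketch (no BrainPy) of the behaviour expected from `jax.disable_jit()`, assuming a recent JAX release in which `jax.lax.scan` falls back to a plain Python loop when jit is disabled. Calling `int(...)` on the carry only succeeds on a concrete array and raises on a tracer, so the list `seen` (a name introduced here for illustration) confirms that the body ran eagerly once per element. For the reproduction above to print real numbers, `bm.scan` would need to behave the same way.

```python
import jax
import jax.numpy as jnp

seen = []  # concrete carry values observed inside the scan body


def cumsum(res, el):
    res = res + el
    # int() only works on a concrete array; on a tracer it would raise,
    # so every successful append shows the body ran eagerly.
    seen.append(int(res))
    return res, res  # ("carryover", "accumulated")


a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

with jax.disable_jit():
    # With jit disabled, lax.scan loops over `a` in Python and calls
    # cumsum once per element with concrete values.
    final, result = jax.lax.scan(cumsum, 0, a)

print(seen)  # → [1, 3, 6, 11, 18, 29, 42, 59]
```

If you remove the `with jax.disable_jit():` line, the same `int(res)` call fails while `lax.scan` traces the body. That is the failure mode to expect if `bm.scan` keeps tracing inside the context manager.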