brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0
493 stars, 90 forks

Using `jax.disable_jit()` cannot disable jit for `bm.scan` #605

Closed. CloudyDory closed this issue 5 months ago.

CloudyDory commented 5 months ago

The jax.disable_jit() context manager disables jit for jax.lax.scan, but not for bm.scan. This makes it hard to debug functions inside bm.scan. See the following code for an example:

import jax
import jax.numpy as jnp
import brainpy.math as bm

def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    res = res + el
    print(res)
    return res, res  # ("carryover", "accumulated")

a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = jax.lax.scan(cumsum, result_init, a)

b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = bm.scan(cumsum, result_init, b)

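The printed output is below. The first eight lines come from jax.lax.scan and show the concrete cumulative sums. The last eight come from bm.scan and show only tracers, so its body is still being traced even inside jax.disable_jit():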

1
3
6
11
18
29
42
59
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
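Under jax.disable_jit(), jax.lax.scan evaluates its body eagerly, element by element, which is why print shows concrete integers for it. The JAX documentation describes the semantics of lax.scan as equivalent to a plain Python loop. The sketch below writes such a loop out; `python_scan` is a hypothetical debugging helper (not part of JAX or BrainPy) that supports only a single array `xs` with no `reverse` or `length` options. Temporarily swapping it in for a scan call is one way to step through a body eagerly, assuming the body does not depend on BrainPy-specific state handling.

```python
import jax.numpy as jnp


def python_scan(f, init, xs):
    """Hypothetical helper: a Python-loop equivalent of jax.lax.scan for a
    single array `xs` (no pytree inputs, no `reverse`/`length` support)."""
    carry = init
    ys = []
    for x in xs:  # iterate over the leading axis of xs
        carry, y = f(carry, x)  # f runs eagerly, so print() sees concrete values
        ys.append(y)
    return carry, jnp.stack(ys)  # stack per-step outputs like lax.scan does


def cumsum(res, el):
    res = res + el
    print(res)
    return res, res


a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
final, result = python_scan(cumsum, 0, a)
```

Each loop iteration prints a concrete cumulative sum, matching the jax.lax.scan lines above.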
chaoming0625 commented 5 months ago

This is a great issue. I have added a new PR to fix this: #606

chaoming0625 commented 5 months ago

This has been solved in #606.
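Independently of the fix in #606, runtime values inside a scanned body can be printed without disabling jit at all by using jax.debug.print, which is staged into the traced computation instead of running only at trace time. The sketch below uses jax.lax.scan under jax.jit; that the same call behaves identically inside bm.scan is an assumption not verified here.

```python
import jax
import jax.numpy as jnp


def cumsum(res, el):
    res = res + el
    # Unlike print(), jax.debug.print emits the runtime value even when
    # this body is traced and compiled.
    jax.debug.print("res = {}", res)
    return res, res


@jax.jit
def run(xs):
    return jax.lax.scan(cumsum, 0, xs)


a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
final, result = run(a)
```

A plain print() in the same body would show only tracers, as in the bm.scan output above.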