CloudyDory closed this issue 1 year ago.
Thanks for the report.
I recommend using the following code:
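```python
import brainpy.math as bm

t = 10.0
x = bm.Variable(bm.zeros(5, dtype=bm.float32))      # time of the latest spike
x_old = bm.Variable(bm.zeros(5, dtype=bm.float32))  # time of the previous spike
d = bm.Variable(bm.ones(5, dtype=bm.float32))       # depression variable
spike = bm.Variable(bm.random.randn(5) > 0)

def do_nothing1(*args):
    # False branch: leave every Variable unchanged.
    return

def func1(t):
    # True branch: update the Variables in place through `.value`;
    # every write is masked so only spiking neurons change.
    x_old.value = bm.where(spike, x, x_old)
    x.value = bm.where(spike, t, x)
    d_next = 1.0 - (1.0 - d * 0.9) * bm.exp(-(x - x_old) / 100.0)
    d.value = bm.where(spike, d_next, d)

bm.cond(bm.any(spike), func1, do_nothing1, t)
```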
If you are using `bm.Variable`, please write the code in this style, assigning to `.value` inside the branch functions.
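For comparison, the same update can be written in plain JAX, where the state is passed in and returned explicitly instead of being mutated through `Variable.value`. This is a minimal sketch under assumptions: it is not the original poster's JAX code (which is not reproduced in this thread), and the names `_update`, `_identity`, `step_cond`, `step_masked` and the `(x, x_old, d, spike)` tuple layout are hypothetical. Because every write in `func1` is already masked by `where(spike, ...)`, running the update unconditionally gives the same result; the conditional only skips work when no neuron spikes.

```python
import jax
import jax.numpy as jnp

TAU = 100.0        # recovery time constant from the formula in the thread
DEPRESSION = 0.9   # factor applied to d in the formula in the thread

def _update(state, t):
    """Spike branch (hypothetical name): the masked updates of func1, returned as new arrays."""
    x, x_old, d, spike = state
    x_old = jnp.where(spike, x, x_old)   # previous spike time
    x = jnp.where(spike, t, x)           # latest spike time
    d_next = 1.0 - (1.0 - d * DEPRESSION) * jnp.exp(-(x - x_old) / TAU)
    d = jnp.where(spike, d_next, d)
    return x, x_old, d, spike

def _identity(state, t):
    """No-spike branch: return the state unchanged."""
    return state

@jax.jit
def step_cond(state, t):
    # lax.cond requires both branches to return the same pytree structure and dtypes.
    return jax.lax.cond(jnp.any(state[3]), _update, _identity, state, t)

@jax.jit
def step_masked(state, t):
    # Branch-free variant: equivalent because every write is masked by `spike`.
    return _update(state, t)

if __name__ == "__main__":
    spike = jax.random.normal(jax.random.PRNGKey(0), (5,)) > 0
    state0 = (jnp.zeros(5), jnp.zeros(5), jnp.ones(5), spike)
    print(step_cond(state0, 10.0)[2])
    print(step_masked(state0, 10.0)[2])
```

The explicit-state version trades the convenience of `Variable` for pure functions, which is what `jax.lax.cond` expects; the masked variant shows that the branch is an optimization rather than a correctness requirement here.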
Thank you, your suggested solution works!
I am trying to implement a custom short-term synaptic depression model in BrainPy. One section of my code is meant to do the following:
However, BrainPy relies on JAX, and if `spike` is traced at runtime, it cannot be used directly in Python control flow. I therefore tried the following workaround:
But running this code generates the following error:
Yet, if I rewrite the code in plain JAX, it works fine:
What is the cause of this error, and how should I deal with it? I am using BrainPy 2.4.4.post1 and JAX 0.4.14.
Thanks!
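To illustrate the restriction described in the question, here is a minimal plain-JAX sketch (not the poster's code; the function names `python_if`, `traced_cond`, and `try_python_if` are hypothetical). Under `jax.jit`, a Python `if` on a traced array fails because it needs a concrete boolean during tracing, whereas `jax.lax.cond` stages both branches and selects one at run time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def python_if(spike, x):
    # Python `if` needs a concrete bool, but under jit `spike` is a tracer.
    if jnp.any(spike):
        return x + 1.0
    return x

@jax.jit
def traced_cond(spike, x):
    # lax.cond traces both branches and chooses one when the program runs.
    return jax.lax.cond(jnp.any(spike), lambda v: v + 1.0, lambda v: v, x)

def try_python_if(spike, x):
    # Return the error instead of raising, to show what tracing reports.
    try:
        return python_if(spike, x)
    except jax.errors.ConcretizationTypeError as err:
        return err

if __name__ == "__main__":
    x = jnp.zeros(2)
    print(type(try_python_if(jnp.array([True, False]), x)).__name__)
    print(traced_cond(jnp.array([True, False]), x))
```

This is the same kind of traced-predicate branching that `bm.cond` provides in the suggested fix.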