brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0
531 stars 94 forks source link

"TypeError: true_fun and false_fun output must have same type structure" in "brainpy.math.cond" function #465

Closed CloudyDory closed 1 year ago

CloudyDory commented 1 year ago

I am trying to implement a custom short-term synaptic depression model in BrainPy. One section of my code needs to do the following:

# Intended logic. This only works when `spike` is concrete: if `spike` is
# traced at runtime, its value cannot drive a Python `if`.
if bm.any(spike):
    x, x_old, d = func(spike, t, x, x_old, d)

However, BrainPy relies on JAX, and when the variable `spike` is traced at runtime, its value cannot be used directly in Python control flow. I therefore tried the following workaround:

import brainpy.math as bm

t = 10.0
x     = bm.Variable(bm.zeros(5, dtype=bm.float32))
x_old = bm.Variable(bm.zeros(5, dtype=bm.float32))
d     = bm.Variable(bm.ones(5, dtype=bm.float32))
spike = bm.Variable(bm.random.randn(5) > 0)

def do_nothing1(spike, t, x, x_old, d):
    """False branch: return the state arrays exactly as received."""
    return x, x_old, d

def func1(spike, t, x, x_old, d):
    """True branch: update (x, x_old, d) at the positions where `spike` is set.

    Returns the new ``(x, x_old, d)`` tuple.
    """
    # Capture the previous x before it is overwritten on the next line.
    x_old = bm.where(spike, x, x_old)
    x = bm.where(spike, t, x)
    d_next = 1.0 - (1.0 - d*0.9) * bm.exp(-(x-x_old)/100.0)
    d = bm.where(spike, d_next, d)
    return x, x_old, d

# Operands must follow the branch signatures (spike, t, x, x_old, d).
# Passing `t` first would silently bind `spike` to 10.0 and `t` to the mask.
# NOTE: this call still raises the "true_fun and false_fun output must have
# same type structure" TypeError. func1 returns the results of bm.where, which
# are brainpy Array objects (the CustomNode(Array...) leaves in the error).
# do_nothing1 returns the raw arrays it received, so the two branches' output
# pytrees differ.
x.value, x_old.value, d.value = bm.cond(
    bm.any(spike), func1, do_nothing1,
    (spike.value, t, x.value, x_old.value, d.value))

But running this code generates the following error:

Traceback (most recent call last):

  File ~\miniconda3\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File d:\xxx\untitled2.py:27
    x.value, x_old.value, d.value = bm.cond(bm.any(spike), func1, do_nothing1, (t, spike.value, x.value, x_old.value, d.value))

  File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\controls.py:539 in cond
    dyn_vars, rets = evaluate_dyn_vars(

  File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\_tools.py:97 in evaluate_dyn_vars
    rets = jax.eval_shape(f2, *args, **kwargs)

  File ~\miniconda3\Lib\site-packages\jax\_src\traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~\miniconda3\Lib\site-packages\jax\_src\api.py:2807 in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,

  File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:670 in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(

  File ~\miniconda3\Lib\site-packages\jax\_src\profiler.py:314 in wrapper
    return func(*args, **kwargs)

  File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2155 in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

  File ~\miniconda3\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2177 in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)

  File ~\miniconda3\Lib\site-packages\jax\_src\linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~\miniconda3\Lib\site-packages\jax\_src\linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~\miniconda3\Lib\site-packages\brainpy\_src\math\object_transform\controls.py:452 in call_fun
    return jax.lax.cond(pred, _true_fun, _false_fun, dyn_vars.dict_data(), *operands)

  File ~\miniconda3\Lib\site-packages\jax\_src\traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\conditionals.py:292 in cond
    return _cond(*args, **kwargs)

  File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\conditionals.py:252 in _cond
    _check_tree_and_avals("true_fun and false_fun output",

  File ~\miniconda3\Lib\site-packages\jax\_src\lax\control_flow\common.py:198 in _check_tree_and_avals
    raise TypeError(

TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef(({}, (CustomNode(Array[None], [*]), CustomNode(Array[None], [*]), CustomNode(Array[None], [*])))) and PyTreeDef(({}, (*, *, *))).

Yet if I translate the code to plain JAX, it works fine:

import jax
import jax.numpy as jnp

t = 10.0
x     = jnp.zeros(5, dtype=jnp.float32)
x_old = jnp.zeros(5, dtype=jnp.float32)
d     = jnp.ones(5, dtype=jnp.float32)
spike = jax.random.normal(jax.random.PRNGKey(0), (5,)) > 0

@jax.jit
def do_nothing2(spike, t, x, x_old, d):
    """False branch: return the state arrays unchanged."""
    return x, x_old, d

@jax.jit
def func2(spike, t, x, x_old, d):
    """True branch: update (x, x_old, d) at the positions where `spike` is set.

    Uses the same divisor (100.0) in the exponent as the BrainPy version, so
    that this is a faithful translation. Returns the new ``(x, x_old, d)``.
    """
    # Capture the previous x before it is overwritten on the next line.
    x_old = jnp.where(spike, x, x_old)
    x = jnp.where(spike, t, x)
    d_next = 1.0 - (1.0 - d*0.9) * jnp.exp(-(x-x_old)/100.0)
    d = jnp.where(spike, d_next, d)
    return x, x_old, d

# Operands follow the branch signatures: (spike, t, x, x_old, d).
x, x_old, d = jax.lax.cond(jnp.any(spike), func2, do_nothing2, spike, t, x, x_old, d)

I would like to know what causes this error and how to deal with it. I am using BrainPy version 2.4.4.post1 and JAX version 0.4.14.

Thanks!

chaoming0625 commented 1 year ago

Thanks for the report.

I recommend using the following code:

import brainpy.math as bm

t = 10.0
x = bm.Variable(bm.zeros(5, dtype=bm.float32))
x_old = bm.Variable(bm.zeros(5, dtype=bm.float32))
d = bm.Variable(bm.ones(5, dtype=bm.float32))
spike = bm.Variable(bm.random.randn(5) > 0)


def do_nothing1(*args):
    """False branch: leave every Variable untouched and return nothing."""


def func1(t):
    """True branch: update x, x_old and d in place where `spike` is set.

    Only `t` arrives as an operand. The Variables are read and assigned
    directly from the enclosing scope, so this branch returns nothing,
    just like the false branch.
    """
    # x_old has to capture x *before* x is overwritten on the next line.
    x_old.value = bm.where(spike, x, x_old)
    x.value = bm.where(spike, t, x)
    elapsed = x - x_old
    d.value = bm.where(spike, 1.0 - (1.0 - d * 0.9) * bm.exp(-elapsed / 100.0), d)


bm.cond(bm.any(spike), func1, do_nothing1, t)

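If you are using `Variable`s, please follow this style: read and update them directly inside the branch functions, instead of passing their values as operands and returning them.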

CloudyDory commented 1 year ago

Thank you, your suggested solution works!