brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Neural mass model demo raises an error and gives a different result from the documentation #432

Closed: CloudyDory closed this issue 1 year ago

CloudyDory commented 1 year ago


Hi, I installed BrainPy (2.4.3.post3) yesterday and am following the demos in the quick start guide. The first two demos (EI balance network and decision network) run fine, but the third one, the neural mass model, raises an error and produces a different result.

Here is my code:

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

print(bp.__version__)
bm.set_platform('cpu')

# 2-dimensional Wilson-Cowan model (x: excitatory, y: inhibitory),
# simulated for 2 nodes with different initial conditions
wc = bp.rates.WilsonCowanModel(2,
                               wEE=16., wIE=15., wEI=12., wII=3.,
                               E_a=1.5, I_a=1.5, E_theta=3., I_theta=3.,
                               method='exp_euler_auto',
                               x_initializer=bm.asarray([-0.2, 1.]),
                               y_initializer=bm.asarray([0.0, 1.]))

# Run for a duration of 10. with a constant external input of -0.5
runner = bp.DSRunner(wc, monitors=['x', 'y'], inputs=['input', -0.5])
runner.run(10.)

# Left panel: excitatory activity x; right panel: inhibitory activity y
fig, gs = bp.visualize.get_figure(1, 2, 4, 3)
ax = fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, plot_ids=[0, 1], legend='e', ax=ax)
ax = fig.add_subplot(gs[0, 1])
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, plot_ids=[0, 1], legend='i', ax=ax, show=True)

# Bifurcation analysis of (x, y) with respect to the external input x_ext
bf = bp.analysis.Bifurcation2D(
  wc,
  target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
  target_pars={'x_ext': [-2, 2]},
  pars_update={'y_ext': 0.},
  resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)  # simulate to find limit cycles
bf.show_figure()

Here is the command line output and error trace:

2.4.3.post3
Predict 100 steps: : 100%|██████████| 100/100 [00:00<00:00, 1236.68it/s]
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
    There are 40000 candidates
I am trying to filter out duplicate fixed points ...
Traceback (most recent call last):

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/core.py:697 in __getattr__
    attr = getattr(self.aval, name)

AttributeError: 'ShapedArray' object has no attribute 'split'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/Disks/D/Users/Yunhui/Project/BrainPy/Neural_Mass_Model_Demo.py:40
    bf.plot_bifurcation()

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py:281 in plot_bifurcation
    jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/pjit.py:253 in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/pjit.py:161 in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/api.py:324 in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/pjit.py:491 in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/pjit.py:969 in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:345 in memoized_fun
    ans = call(fun, *args)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/pjit.py:922 in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/profiler.py:314 in wrapper
    return func(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155 in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177 in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/api.py:1258 in vmap_f
    out_flat = batching.batch(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py:218 in __call__
    rets = jax.eval_shape(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/api.py:2807 in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:670 in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/profiler.py:314 in wrapper
    return func(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155 in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177 in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py:480 in jacfun
    jac = vmap(pullback)(_std_basis(y))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py:461 in _std_basis
    return _unravel_array_into_pytree(pytree, 1, flat_basis)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py:451 in _unravel_array_into_pytree
    parts = arr.split(np.cumsum(safe_map(np.size, leaves[:-1])), axis)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/core.py:699 in __getattr__
    raise AttributeError(

AttributeError: DynamicJaxprTracer has no attribute split

And the output figure is different from the documentation:

[Figure_1: output figure of the script above]

How can I fix this issue? Thanks!

chaoming0625 commented 1 year ago

Hi @CloudyDory, many thanks for your report. The error above is caused by a compatibility issue with the latest JAX release; see #431. I have released brainpy==2.4.3.post4, which solves this kind of issue.
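For context on the traceback: its last BrainPy frame calls arr.split(...) as a method on an object that, under the newer JAX, is a bare DynamicJaxprTracer with no such attribute. A minimal sketch of a tracer-safe way to split a flat array, assuming the remedy amounts to calling the function form jnp.split instead of a method. This is an illustration only, not the actual change shipped in 2.4.3.post4; unravel_flat and roundtrip are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def unravel_flat(flat, shapes):
    """Hypothetical helper: split a 1-D array into pieces with the given shapes.

    Uses the function form jnp.split, which accepts tracers inside
    jax.jit / jax.vmap, rather than relying on an array `.split` method.
    """
    sizes = [int(np.prod(s)) for s in shapes]
    # Split points are the cumulative sizes of every piece except the last.
    split_points = np.cumsum(sizes[:-1]).tolist()
    parts = jnp.split(flat, split_points)
    return [p.reshape(s) for p, s in zip(parts, shapes)]


@jax.jit
def roundtrip(flat):
    # Inside jit, `flat` is a tracer; jnp.split still works on it.
    a, b = unravel_flat(flat, [(2,), (2, 2)])
    return a.sum() + b.sum()


print(roundtrip(jnp.arange(6.0)))  # 15.0
```

The split points are computed with NumPy from static shapes, so they stay concrete Python integers even while the array itself is being traced.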

CloudyDory commented 1 year ago

Hi @chaoming0625, thank you for the fast reply! I have updated to 2.4.3.post4 and the error is gone, but the output figure is still the same, and still different from the documentation.

chaoming0625 commented 1 year ago

The current results are correct. WilsonCowanModel was revised in the releases after 2.4.1, so its output no longer matches the older figure.

You can verify your simulation results with a phase plane analysis:

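bp.math.enable_x64()  # enable 64-bit floating point precision
pp = bp.analysis.PhasePlane2D(
  wc,
  target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
  # target_pars={'x_ext': [-2, 2]},
  # x_ext = -0.5 matches the external input used in the simulation
  pars_update={'y_ext': 0., 'x_ext': -0.5},
  resolutions=0.001
)
pp.plot_nullcline()      # nullclines of x and y
pp.plot_fixed_point()    # find and plot the fixed points
pp.plot_vector_field()   # direction field of (dx/dt, dy/dt)
pp.show_figure()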

With an external input of -0.5, there is only a single, globally stable attractor.

[Image: phase plane with nullclines, fixed point, and vector field]

Therefore, all simulations converge to the point [0., 0.].
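The argument above rests on the stability of that fixed point. A minimal sketch of the standard linear-stability test, using a hypothetical toy 2-D vector field rather than the WilsonCowanModel equations: a fixed point is locally stable when every eigenvalue of the Jacobian evaluated there has a negative real part. This check only establishes local stability; the "global" part of the claim comes from the phase plane showing a single attractor. toy_field and is_locally_stable are hypothetical names.

```python
import jax
import jax.numpy as jnp


def toy_field(v):
    """Hypothetical 2-D vector field (not the Wilson-Cowan model).

    The origin is a fixed point because tanh(0) = 0.
    """
    x, y = v[0], v[1]
    return jnp.stack([-x + 0.5 * jnp.tanh(y),
                      -y - 0.5 * jnp.tanh(x)])


def is_locally_stable(field, fixed_point):
    """Return (stable, eigenvalues) of the Jacobian at `fixed_point`.

    Stable if every eigenvalue has a strictly negative real part.
    """
    jac = jax.jacfwd(field)(fixed_point)
    eigs = jnp.linalg.eigvals(jac)
    return bool(jnp.all(eigs.real < 0)), eigs


fp = jnp.zeros(2)
print(jnp.allclose(toy_field(fp), 0.0))  # True: the origin is a fixed point
stable, eigs = is_locally_stable(toy_field, fp)
print(stable)  # True: eigenvalues are -1 +/- 0.5j, a stable focus
```

Here the Jacobian at the origin is [[-1, 0.5], [-0.5, -1]], whose eigenvalues -1 +/- 0.5j have negative real parts, so nearby trajectories spiral into the fixed point.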

CloudyDory commented 1 year ago

Thank you for the verification! I guess the figure in the quick start guide should be updated, then.