brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0
493 stars 90 forks source link

TypeError when running a simulation #593

Closed PikaPei closed 6 months ago

PikaPei commented 6 months ago

Hello!

I am new to BrainPy and find it a great simulation tool!

But when I played with it, I met some errors and couldn't find solutions in the documentation. I want to make two variables, var1 and var2, each receiving spike inputs from distinct SpikeTimeGroups. Each variable shows simple exponential decay dynamics with different time constants. I'm not sure if I'm doing something wrong, and I would appreciate any advice.

Here is my code and my BrainPy version is 2.4.6.post5,

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

class Model(bp.NeuGroup):
    def __init__(self, size, var1_tau=500, var2_tau=1000):
        super().__init__(size=size)
        self.var1_pre = bp.neurons.SpikeTimeGroup(1, times=[100], indices=[0])
        self.var2_pre = bp.neurons.SpikeTimeGroup(1, times=[200], indices=[0])

        self.var1_tau = var1_tau
        self.var1 = bm.Variable(bm.zeros(self.num))

        self.var2_tau = var2_tau
        self.var2 = bm.Variable(bm.zeros(self.num))

        self.integral = bp.odeint(bp.JointEq(self.dvar1, self.dvar2), method="exp_auto")

    def dvar1(self, var1, t):
        dvar1dt = -var1 / self.var1_tau
        return dvar1dt

    def dvar2(self, var2, t):
        dvar2dt = -var2 / self.var2_tau
        return dvar2dt

    def update(self):
        t = bp.share["t"]
        dt = bp.share["dt"]

        self.var1_pre.update()
        self.var2_pre.update()

        self.var1.value = self.integral(self.var1, t, dt=dt) + self.var1_pre.spike
        self.var2.value = self.integral(self.var2, t, dt=dt) + self.var2_pre.spike

    def run(self, duration):
        self.runner = bp.DSRunner(
            self,
            monitors=["var1", "var2"],
        )

        self.runner.run(duration)

if __name__ == "__main__":
    model = Model(1)
    model.run(1000)

    plt.plot(model.runner.mon.ts, model.runner.mon.var1)
    plt.plot(model.runner.mon.ts, model.runner.mon.var2)
    plt.show()

The error is:

Traceback (most recent call last):
  File "/Users/pei/project/comp-neuro-brainpy/test.py", line 49, in <module>
    model.run(1000)
  File "/Users/pei/project/comp-neuro-brainpy/test.py", line 44, in run
    self.runner.run(duration)
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 512, in run
    return self.predict(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 485, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 539, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 662, in _fun_predict
    return bm.for_loop(self._step_func_predict,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 877, in for_loop
    rets = jax.eval_shape(transform, operands)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 730, in call
    return jax.lax.scan(f=fun2scan,
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py", line 721, in fun2scan
    results = body_fun(*x, **unroll_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/runners.py", line 628, in _step_func_predict
    out = self.target(*x)
          ^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/dynsys.py", line 378, in __call__
    ret = self.update(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/dynsys.py", line 330, in _compatible_update
    return update_fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/project/comp-neuro-brainpy/test.py", line 35, in update
    self.var1.value = self.integral(self.var1, t, dt=dt) + self.var1_pre.spike
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/base.py", line 114, in __call__
    new_vars = self.integral(**kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/exponential.py", line 332, in integral_func
    r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in})
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/integrators/ode/exponential.py", line 360, in integral
    linear, derivative = value_and_grad(*args, **kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 209, in __call__
    rets = self._transform(
           ^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 771, in grad_fun
    y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pei/.pyenv/versions/mambaforge-22.9.0-2/envs/brainpy-env/lib/python3.11/site-packages/brainpy/_src/math/object_transform/autograd.py", line 150, in _f_grad_without_aux_to_transform
    output = self.target(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Model.dvar1() missing 1 required positional argument: 't'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Thank you!

chaoming0625 commented 6 months ago

Thanks for the question. The error is caused because you are using the joint equation while giving the parameters independently.

self.integral = bp.odeint(bp.JointEq(self.dvar1, self.dvar2), method="exp_auto")

One way to solve this issue is modifying your update function as:

self.var1.value, self.var2.value = self.integral(self.var1, self.var2, t, dt=dt) 
self.var1 += self.var1_pre.spike
self.var2 += self.var2_pre.spike
PikaPei commented 6 months ago

I see. It works now! Thank you for the helpful answer!