brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Question #692

Open madara-1112 opened 3 weeks ago

madara-1112 commented 3 weeks ago

Hello, I have a question about the function `pre2post_event_sum`. The documentation says: "The pre-to-post event-driven synaptic summation with CSR synapse structure. When `values` is a scalar, this function is equivalent to `post_val = np.zeros(post_num) post_ids, idnptr = pre2post for i in range(pre_num): if events[i]: for j in range(idnptr[i], idnptr[i+1]): post_val[post_ids[i]] += values`". But I wonder whether `post_val[post_ids[i]] += values` was written incorrectly and should be changed to `post_val[post_ids[j]] += values`.
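For reference, the documented loop with the index corrected to `post_ids[j]` can be written as a plain NumPy sketch. The function name `pre2post_event_sum_reference` is hypothetical (it is not a brainpy API), and the CSR layout is assumed to follow the docstring: `post_ids` holds postsynaptic indices and `indptr[i]:indptr[i+1]` (spelled `idnptr` in the docstring) delimits the synapses of presynaptic neuron `i`.

```python
import numpy as np


def pre2post_event_sum_reference(events, pre2post, post_num, values):
    """Hypothetical reference loop for CSR event-driven summation (not brainpy's code)."""
    post_ids, indptr = pre2post
    post_val = np.zeros(post_num)
    for i in range(len(events)):
        if events[i]:  # only presynaptic neurons that spiked contribute
            for j in range(indptr[i], indptr[i + 1]):
                # index with the synapse index j, not the presynaptic index i
                post_val[post_ids[j]] += values
    return post_val


# Tiny example: 3 presynaptic neurons, 4 postsynaptic neurons.
# Neuron 0 -> {1, 3}, neuron 1 -> {0}, neuron 2 -> {1, 2}
post_ids = np.array([1, 3, 0, 1, 2])
indptr = np.array([0, 2, 3, 5])
events = np.array([True, False, True])
# Neurons 0 and 2 fire, so post 1 receives two events and posts 2 and 3 one each
print(pre2post_event_sum_reference(events, (post_ids, indptr), 4, 0.5))
```

With `post_ids[i]` instead of `post_ids[j]`, every synapse of neuron `i` would add to the same (wrong) target, which is why the index matters.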

madara-1112 commented 3 weeks ago

I may also have encountered another bug. I tried to simulate the firing of an HH neuron using the following code. When I set the input strength to 0.280 and the duration to 1000, the figure was obviously incorrect. [screenshot] When I set the strength to other values such as 0.278, the simulation worked correctly, as shown below. [screenshot]

```python
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt


class HH(bp.dyn.NeuGroup):

    def __init__(self, size, ENa=50., gNa=1.2, EK=-77.,
                 gK=0.36, EL=-54.387, gL=0.003, V_th=0., C=0.01):
        super(HH, self).__init__(size=size)
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.C = C
        self.V_th = V_th

        self.V = bm.Variable(-65 * bm.ones(self.num))
        self.m = bm.Variable(0.0529 * bm.ones(self.num))
        self.h = bm.Variable(0.5961 * bm.ones(self.num))
        self.n = bm.Variable(0.3177 * bm.ones(self.num))
        self.gNa_ = bm.Variable(0 * bm.ones(self.num))
        self.gK_ = bm.Variable(0 * bm.ones(self.num))

        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

        self.integral = bp.odeint(f=self.derivative, method='exp_auto')

    @property
    def derivative(self):
        return bp.JointEq(self.dV, self.dm, self.dh, self.dn)

    def dm(self, m, t, V):
        alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
        beta = 4.0 * bm.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        return dmdt

    def dh(self, h, t, V):
        alpha = 0.07 * bm.exp(-(V + 65) / 20.)
        beta = 1 / (1 + bm.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        return dhdt

    def dn(self, n, t, V):
        alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
        beta = 0.125 * bm.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        return dndt

    def dV(self, V, t, h, n, m):
        I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
        I_K = (self.gK * n ** 4.0) * (V - self.EK)
        I_leak = self.gL * (V - self.EL)
        dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
        return dVdt

    def update(self, tdi):
        t, dt = tdi.t, tdi.dt
        V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, dt=dt)
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, t, self.t_last_spike)
        self.V.value = V
        self.m.value = m
        self.h.value = h
        self.n.value = n

        self.gNa_.value = self.gNa * m ** 3.0 * h  # record the sodium conductance
        self.gK_.value = self.gK * n ** 4.0  # record the potassium conductance
        self.input[:] = 0.  # reset the input received by the neuron


currents, length = bp.inputs.section_input(values=[0.280], durations=[1000], return_length=True)
hh = HH(1)
runner = bp.dyn.DSRunner(hh, monitors=['V', 'm', 'h', 'n', 'gNa_', 'gK_'],
                         inputs=['input', currents, 'iter'])
runner.run(length)

fig, axe = plt.subplots(2, 1)
axe[0].plot(runner.mon.ts, runner.mon.V, linewidth=2)
print(runner.mon.V.shape)
axe[0].set_ylabel('V(mV)')

axe[1].plot(runner.mon.ts, runner.mon.gNa_, linewidth=2, color='blue')
axe[1].plot(runner.mon.ts, runner.mon.gK_, linewidth=2, color='red')
axe[1].set_ylabel('Conductance')

axe[1].plot(runner.mon.ts, runner.mon.m, linewidth=2, color='blue', label='m')
axe[1].plot(runner.mon.ts, runner.mon.n, linewidth=2, color='red', label='n')
axe[1].plot(runner.mon.ts, runner.mon.h, linewidth=2, color='orange', label='h')
axe[1].set_ylabel('Channel')
plt.legend()
plt.tight_layout()
plt.show()
```

madara-1112 commented 3 weeks ago

The version I use is brainpy 2.4.5

Routhleck commented 2 weeks ago

Thank you for your question. Indeed, in the documentation, it should be corrected from "post_val[post_ids[i]] += values" to "post_val[post_ids[j]] += values". We will make the necessary changes to the documentation as soon as possible.

Regarding your second question, I'm not quite sure I understand it. Could you share the code that reproduces the problem and specify which parameter values trigger it? You could also try upgrading brainpy to the latest version:

pip install brainpy -U

madara-1112 commented 2 weeks ago

Thank you for answering my first question! As for the second question, the original code is attached (code.txt). I'm not sure what caused the problem, so let me restate it: when the current strength was set to 0.280, there was an obvious error in the plot, which only showed the first 5000 time steps even though I had set the model to simulate 10000 time steps. However, when I set the current strength to another value, such as 0.278, the plot correctly showed all time steps. I don't know why this particular current strength triggers the bug.

madara-1112 commented 2 weeks ago

I adjusted the simulation length and found that when the current strength is set to 0.280, the plot shows at most the first 5000 time points, no matter how many time steps I set. That is, it shows all time steps correctly only if the number of time steps is below 5000.

madara-1112 commented 2 weeks ago

I might also try upgrading brainpy😂

Routhleck commented 2 weeks ago

I think I know what the issue is. Is your JAX version 0.4.32 or later? JAX introduced an asynchronous CPU dispatch mechanism in version 0.4.32, which can cause `runner.run()` to return prematurely and let the subsequent code execute before the simulation has finished. You can consider downgrading JAX to a version below 0.4.31, or changing `runner.run(length)` to `jax.block_until_ready(runner.run(length))`.
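The idea behind the suggested workaround is to make the host wait for the device result before running any plotting code. A minimal, self-contained JAX sketch of that pattern is below; the jitted function `simulate_decay` is hypothetical and only stands in for `runner.run(length)`. `jax.block_until_ready` waits until every array in the returned value has actually been computed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def simulate_decay(v0):
    """Hypothetical stand-in for a simulation: 100 Euler steps of dv/dt = -v."""
    dt = 0.01

    def step(_, v):
        return v + dt * (-v)

    return jax.lax.fori_loop(0, 100, step, v0)


# Under asynchronous dispatch the call can return before the computation is done;
# block_until_ready makes Python wait for the finished result.
v_final = jax.block_until_ready(simulate_decay(jnp.ones(3)))
print(v_final)
```

Each entry ends up close to 0.99 ** 100. Wrapping the call this way only guarantees the result is ready; it does not change the numerical result.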

Routhleck commented 2 weeks ago

Could you please provide the specific hardware information of your device? It seems that my device has difficulty reproducing the error.

madara-1112 commented 2 weeks ago

The JAX version (also jaxlib) I used is 0.4.16. I ran this code on MacBook Pro, M2, macOS Sequoia 15.2.

madara-1112 commented 2 weeks ago

I changed runner.run(length) to jax.block_until_ready(runner.run(length)) but it did not work either.

Routhleck commented 2 weeks ago

@ztqakita

madara-1112 commented 2 weeks ago

I wanted to see how the firing rate varies with the current strength, and got this output 😂:

    duration = 1000
    I = np.arange(0, 0.5, 0.002)
    group = HH(len(I))
    runner = bp.dyn.DSRunner(group, monitors=['spike'], inputs=['input', I])
    runner(duration=duration)
    F = runner.mon.spike.sum(axis=0) / (duration / 1000)
    print(F)
    plt.plot(I, F, linewidth=2)
    plt.xlabel('I(mA/mm^2)')
    plt.ylabel('F(Hz)')
    plt.title('firing rate vs current')
    plt.show()

[Figure_1: firing rate vs current]
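The rate in the snippet above is the spike count divided by the duration converted from ms to s, so for a 1000 ms run the count equals the rate in Hz. A NumPy sketch of the same calculation on a boolean raster of shape `(n_steps, n_neurons)` is below; the helper name `firing_rate_hz` is hypothetical, and the raster shape is an assumption about how the spike monitor is laid out.

```python
import numpy as np


def firing_rate_hz(spikes, duration_ms):
    """Spike count per neuron divided by the simulated time in seconds.

    spikes: boolean array of shape (n_steps, n_neurons), e.g. a monitored spike raster.
    duration_ms: simulated duration in milliseconds.
    """
    spikes = np.asarray(spikes, dtype=bool)
    return spikes.sum(axis=0) / (duration_ms / 1000.0)


# Toy raster: 1000 ms at dt = 0.1 ms gives 10000 steps, 2 neurons.
spikes = np.zeros((10000, 2), dtype=bool)
spikes[[100, 4000, 9000], 0] = True  # neuron 0 fires 3 times, neuron 1 never
print(firing_rate_hz(spikes, 1000))  # neuron 0: 3 spikes in 1 s, i.e. 3 Hz
```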

I printed `V[5000:10000]` with the current strength set to 0.280 and found that the values were all NaN.
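A quick way to locate where such a trace breaks is to search for the first NaN in the monitored array. The helper `first_nan_step` below is a hypothetical NumPy sketch that assumes a monitor of shape `(n_steps, n_neurons)`. Because every state update depends on `V`, a single NaN propagates to all later steps; comparisons with NaN evaluate to False, so those steps also record no spikes; and Matplotlib leaves gaps where data is NaN, which would make the plot look truncated.

```python
import numpy as np


def first_nan_step(trace):
    """Return the first time index at which any neuron's value is NaN, or None.

    trace: array of shape (n_steps, n_neurons), e.g. a monitored voltage.
    """
    trace = np.asarray(trace, dtype=float)
    bad_steps = np.isnan(trace).reshape(trace.shape[0], -1).any(axis=1)
    idx = np.flatnonzero(bad_steps)
    return int(idx[0]) if idx.size else None


# Toy trace: the neuron becomes NaN at step 5000 and stays NaN afterwards.
V = np.full((10000, 1), -65.0)
V[5000:] = np.nan
print(first_nan_step(V))  # 5000
```

With the model from this thread, `first_nan_step(runner.mon.V)` would be the call to try.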

ztqakita commented 2 weeks ago

There are two places in your code that can be fixed with a few changes, as shown in the screenshot: [screenshot] You can replace `bm.exp` with `bm.exprel` to avoid the NaN problem. When x is near zero, exp(x) is near 1, so computing exp(x) - 1 numerically suffers from catastrophic loss of precision. `exprel(x)`, which computes (exp(x) - 1)/x, is implemented to avoid this loss of precision near zero.
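To make the suggestion concrete: with u = -(V + 40)/10, the original rate alpha_m = 0.1(V + 40)/(1 - exp(u)) equals 1/exprel(u), where exprel(u) = (exp(u) - 1)/u and exprel(0) = 1. The naive form evaluates to 0/0 = NaN if V lands exactly on -40 mV, even though the true limit there is 1. The NumPy sketch below implements `exprel` with `np.expm1` as a stand-in for `bm.exprel` (whose exact implementation is not shown in this thread) and compares both forms; it is an illustration, not a copy of the change in the screenshot.

```python
import numpy as np


def exprel(x):
    """(exp(x) - 1) / x, with the removable singularity at x = 0 set to 1."""
    x = np.asarray(x, dtype=float)
    safe_x = np.where(x == 0.0, 1.0, x)  # keeps the unused branch free of 0/0
    return np.where(x == 0.0, 1.0, np.expm1(safe_x) / safe_x)


def alpha_m_naive(V):
    # Original form from the issue: 0/0 gives NaN when V == -40 exactly
    return 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))


def alpha_m_stable(V):
    # Same function rewritten as 1 / exprel(u) with u = -(V + 40) / 10
    return 1.0 / exprel(-(V + 40) / 10)


V = np.array([-70.0, -40.0, -20.0])
with np.errstate(invalid="ignore"):
    print(alpha_m_naive(V))   # middle entry is nan
print(alpha_m_stable(V))      # middle entry is 1.0
```

The same substitution applies to the alpha term in `dn`, whose singular point is V = -55 mV.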

madara-1112 commented 2 weeks ago

Thank you for your reply! It seems I need to update brainpy to use this function; I'll try it.

madara-1112 commented 2 weeks ago

I updated brainpy and JAX and replaced `bm.exp` with `bm.exprel` as you did. It now works at a current strength of 0.280, but I'm not sure it works at other current strengths: the firing rates below show problems I never encountered before replacing the function. [Figure_1] I also set the current strength to zero and was surprised to find that the HH neuron still produced spikes. [Figure_1]
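One way to check whether a rewritten rate function still matches the original is to compare the two numerically on a voltage grid that avoids the singular point; a sign or argument mistake (for example passing +(V + 55)/10 where -(V + 55)/10 is needed) shows up immediately as a large mismatch. The NumPy sketch below does this for alpha_n. `alpha_n_rewritten` uses `u / np.expm1(u)` as a hypothetical stand-in for an exprel-based version, and `alpha_n_wrong_sign` is a deliberately broken variant that shows what a failed check looks like. This only tests the formula; it does not by itself explain the spikes at zero current.

```python
import numpy as np


def alpha_n_original(V):
    # Form used in the issue (singular at V == -55)
    return 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))


def alpha_n_rewritten(V):
    # 0.01*(V+55) / (1 - exp(u)) == 0.1 * u / expm1(u), with u = -(V + 55) / 10
    u = np.asarray(-(V + 55) / 10, dtype=float)
    safe_u = np.where(u == 0.0, 1.0, u)
    return np.where(u == 0.0, 0.1, 0.1 * safe_u / np.expm1(safe_u))


def alpha_n_wrong_sign(V):
    # Deliberately broken: the sign of u is flipped
    u = np.asarray((V + 55) / 10, dtype=float)
    safe_u = np.where(u == 0.0, 1.0, u)
    return np.where(u == 0.0, 0.1, 0.1 * safe_u / np.expm1(safe_u))


def max_rel_error(candidate, reference, V):
    """Largest relative difference between two rate functions on a voltage grid."""
    a, b = candidate(V), reference(V)
    return float(np.max(np.abs(a - b) / np.abs(b)))


# Grid from -100 mV to 50 mV, excluding the singular point at -55 mV
V = np.linspace(-100.0, 50.0, 1000)
V = V[np.abs(V + 55.0) > 1e-6]
print(max_rel_error(alpha_n_rewritten, alpha_n_original, V))   # round-off level
print(max_rel_error(alpha_n_wrong_sign, alpha_n_original, V))  # large
```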

madara-1112 commented 2 weeks ago

When I went back to the original function `bm.exp`, it behaved just as before, working well at all current strengths except 0.280, which rules out the package updates as the cause of the new behavior.