Open madara-1112 opened 3 weeks ago
I may also have encountered another bug. I tried to simulate the firing rate of an HH neuron using the following code. When I set the input strength to 0.280 and the duration to 1000, the figure was obviously incorrect, whereas other strengths such as 0.278 worked correctly.
```python
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt


class HH(bp.dyn.NeuGroup):
    def __init__(self, size, ENa=50., gNa=1.2, EK=-77., gK=0.36,
                 EL=-54.387, gL=0.003, V_th=0., C=0.01):
        super(HH, self).__init__(size=size)
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.C = C
        self.V_th = V_th
        self.V = bm.Variable(-65 * bm.ones(self.num))
        self.m = bm.Variable(0.0529 * bm.ones(self.num))
        self.h = bm.Variable(0.5961 * bm.ones(self.num))
        self.n = bm.Variable(0.3177 * bm.ones(self.num))
        self.gNa_ = bm.Variable(0 * bm.ones(self.num))
        self.gK_ = bm.Variable(0 * bm.ones(self.num))
        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
        self.integral = bp.odeint(f=self.derivative, method='exp_auto')

    @property
    def derivative(self):
        return bp.JointEq(self.dV, self.dm, self.dh, self.dn)

    def dm(self, m, t, V):
        alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
        beta = 4.0 * bm.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        return dmdt

    def dh(self, h, t, V):
        alpha = 0.07 * bm.exp(-(V + 65) / 20.)
        beta = 1 / (1 + bm.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        return dhdt

    def dn(self, n, t, V):
        alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
        beta = 0.125 * bm.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        return dndt

    def dV(self, V, t, h, n, m):
        I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
        I_K = (self.gK * n ** 4.0) * (V - self.EK)
        I_leak = self.gL * (V - self.EL)
        dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
        return dVdt

    def update(self, tdi):
        t, dt = tdi.t, tdi.dt
        V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, dt=dt)
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, t, self.t_last_spike)
        self.V.value = V
        self.m.value = m
        self.h.value = h
        self.n.value = n
        self.gNa_.value = self.gNa * m ** 3.0 * h  # record the sodium conductance
        self.gK_.value = self.gK * n ** 4.0  # record the potassium conductance
        self.input[:] = 0.  # reset the input received by the neuron


currents, length = bp.inputs.section_input(values=[0.280], durations=[1000],
                                           return_length=True)
hh = HH(1)
# monitor the recorded Variable 'gNa_', not the float parameter 'gNa'
runner = bp.dyn.DSRunner(hh, monitors=['V', 'm', 'h', 'n', 'gNa_', 'gK_'],
                         inputs=['input', currents, 'iter'])
runner.run(length)

fig, axe = plt.subplots(2, 1)
axe[0].plot(runner.mon.ts, runner.mon.V, linewidth=2)
print(runner.mon.V.shape)
axe[0].set_ylabel('V(mV)')
axe[1].plot(runner.mon.ts, runner.mon.m, linewidth=2, color='blue', label='m')
axe[1].plot(runner.mon.ts, runner.mon.n, linewidth=2, color='red', label='n')
axe[1].plot(runner.mon.ts, runner.mon.h, linewidth=2, color='orange', label='h')
axe[1].set_ylabel('Channel')
plt.legend()
plt.tight_layout()
plt.show()
```
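As a sanity check on the initial gate values in `__init__`, here is a small pure-Python sketch (the helpers `rates` and `steady_state` are hypothetical, and it uses `math` rather than brainpy) showing that `m = 0.0529`, `h = 0.5961` and `n = 0.3177` are the steady states x∞ = α / (α + β) of the same rate functions at V = -65 mV, i.e. the neuron starts at rest.

```python
import math


def rates(V):
    """(alpha, beta) pairs for the m, h, n gates, copied from dm, dh, dn above."""
    alpha_m = 0.1 * (V + 40) / (1 - math.exp(-(V + 40) / 10))
    beta_m = 4.0 * math.exp(-(V + 65) / 18)
    alpha_h = 0.07 * math.exp(-(V + 65) / 20.)
    beta_h = 1 / (1 + math.exp(-(V + 35) / 10))
    alpha_n = 0.01 * (V + 55) / (1 - math.exp(-(V + 55) / 10))
    beta_n = 0.125 * math.exp(-(V + 65) / 80)
    return {"m": (alpha_m, beta_m), "h": (alpha_h, beta_h), "n": (alpha_n, beta_n)}


def steady_state(V):
    """x_inf = alpha / (alpha + beta) for each gate at a fixed voltage V."""
    return {gate: a / (a + b) for gate, (a, b) in rates(V).items()}


print(steady_state(-65.0))  # approximately m = 0.0529, h = 0.5961, n = 0.3177
```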
The version I use is brainpy 2.4.5
Thank you for your question. Indeed, in the documentation, it should be corrected from "post_val[post_ids[i]] += values" to "post_val[post_ids[j]] += values". We will make the necessary changes to the documentation as soon as possible.
Regarding your second question, I am not quite clear about it. Could you provide the problematic code and specify which parameter values trigger the issue? Perhaps you could also try upgrading brainpy to the latest version.
```bash
pip install brainpy -U
```
Thank you for answering my first question! As for the second question, the original code is attached (code.txt). I'm not sure what causes the problem, so let me restate it: when the current strength is set to 0.280, the plot is clearly wrong and only shows the first 5000 time steps, even though the model is set to simulate 10000 time steps. When I set the current strength to another value, such as 0.278, the plot correctly shows all time steps. I don't know why this particular value would trigger the bug.
I then varied the number of simulated time steps and found that, with a current strength of 0.280, the plot shows at most the first 5000 time points no matter how many I set. In other words, all time steps are shown correctly only when the simulation is set to fewer than 5000 steps.
I might also try upgrading brainpy😂
I think I know what the issue is. Is your JAX version 0.4.32 or later? JAX introduced asynchronous CPU dispatch in version 0.4.32, which can cause `runner.run()` to return prematurely and let the subsequent code execute before the simulation has finished. You can consider downgrading JAX to a version below 0.4.31, or changing `runner.run(length)` to `jax.block_until_ready(runner.run(length))`.
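A minimal sketch of the blocking pattern in plain JAX, not brainpy (`run_steps` is a hypothetical stand-in for a long simulation): under asynchronous dispatch a jitted call may hand back its result before the computation has finished, and `jax.block_until_ready` waits until the values are actually available.

```python
import jax
import jax.numpy as jnp


@jax.jit
def run_steps(v0):
    """Hypothetical stand-in for a long simulation: 10,000 leaky updates toward 1.0."""
    def body(_, v):
        return v + 0.01 * (1.0 - v)
    return jax.lax.fori_loop(0, 10_000, body, v0)


# With asynchronous dispatch, this call can return before the work is done.
out = run_steps(jnp.zeros(3))
# Wait for the computation to finish before using the values on the host.
out = jax.block_until_ready(out)
print(out)
```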
Could you please share the specific hardware details of your device? I have had difficulty reproducing the error on mine.
I'm using JAX (and jaxlib) 0.4.16 and ran the code on a MacBook Pro (M2) with macOS Sequoia 15.2.
I changed `runner.run(length)` to `jax.block_until_ready(runner.run(length))`, but that did not fix it either.
@ztqakita I wanted to see how the firing rate varies with the current strength using the code below, and the output looked like this 😂.
```python
duration = 1000
I = np.arange(0, 0.5, 0.002)
group = HH(len(I))
runner = bp.dyn.DSRunner(group, monitors=['spike'], inputs=['input', I])
runner(duration=duration)
# firing rate in Hz: spike count divided by the duration in seconds
F = runner.mon.spike.sum(axis=0) / (duration / 1000)
print(F)
plt.plot(I, F, linewidth=2)
plt.xlabel('I(mA/mm^2)')
plt.ylabel('F(Hz)')
plt.title('firing rate vs current')
plt.show()
```
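A small NumPy sketch, using made-up voltage traces rather than brainpy output, of how the spike rule in `update()` (an upward crossing of `V_th`) and the rate formula interact with NaN: once a neuron's `V` becomes NaN, both comparisons are false, so that neuron silently stops spiking and its rate is undercounted rather than flagged. `detect_spikes` and `firing_rates` are hypothetical helpers.

```python
import numpy as np


def detect_spikes(V, V_th=0.0):
    """Upward threshold crossings between consecutive steps.

    V has shape (n_steps, n_neurons); this mirrors the rule
    spike = (V_prev < V_th) & (V_new >= V_th) used in HH.update().
    """
    with np.errstate(invalid="ignore"):  # comparisons involving NaN yield False
        return (V[:-1] < V_th) & (V[1:] >= V_th)


def firing_rates(spikes, duration_ms):
    """Spike count per neuron divided by the duration in seconds (Hz)."""
    return spikes.sum(axis=0) / (duration_ms / 1000.0)


# Made-up traces for two neurons; neuron 1 turns NaN after the second step.
V = np.array([[-65.0, -65.0],
              [10.0, -60.0],
              [-70.0, np.nan],
              [5.0, np.nan]])
spikes = detect_spikes(V)
rates = firing_rates(spikes, duration_ms=1000)
print(rates)  # [2. 0.]
```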
I printed `V[5000:10000]` with the current strength set to 0.280 and found that the values were all NaN.
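One plausible source of the NaN, sketched in NumPy rather than brainpy (an assumption, not something confirmed in the thread): the `alpha` expression in `dm` evaluates to 0/0 when `V` equals -40 mV exactly, and the one in `dn` does the same at -55 mV. Once a NaN enters the state, every later step inherits it, which would be consistent with the second half of `V` being NaN. `alpha_m_original` is a hypothetical name for the expression copied from `dm`.

```python
import numpy as np


def alpha_m_original(V):
    """alpha for the m gate, written exactly as in the issue's dm()."""
    return 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))


V = np.array([-65.0, -40.0, -30.0], dtype=np.float32)
with np.errstate(invalid="ignore", divide="ignore"):
    a = alpha_m_original(V)
print(a)  # the middle entry (V = -40 mV) is nan because it computes 0 / 0

# Locate the first NaN in a (made-up) monitored trace.
trace = np.array([-65.0, -64.0, np.nan, np.nan, np.nan], dtype=np.float32)
is_nan = np.isnan(trace)
first_nan = int(np.argmax(is_nan)) if is_nan.any() else -1
print(first_nan)  # 2
```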
There are two places in your code, the `alpha` expressions in `dm` and `dn`, that can be fixed with a few changes.

You can replace `bm.exp` with `bm.exprel` to avoid the NaN problem. When `x` is near zero, `exp(x)` is near 1, so the numerical calculation of `exp(x) - 1` can suffer from catastrophic loss of precision. `exprel(x)` is implemented to avoid the loss of precision that occurs when `x` is near zero.
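To make the precision argument concrete, here is a pure-Python sketch. It assumes `bm.exprel(x)` follows the usual definition `(exp(x) - 1) / x`, with the limit 1 at `x = 0`, as `scipy.special.exprel` does; the `exprel` below is a hypothetical stand-in built on `math.expm1`, not brainpy's implementation.

```python
import math


def exprel(x):
    """(exp(x) - 1) / x evaluated without cancellation; equals 1 at x == 0."""
    if x == 0.0:
        return 1.0
    return math.expm1(x) / x


def exprel_naive(x):
    """Direct formula: exp(x) - 1 cancels catastrophically for tiny x."""
    return (math.exp(x) - 1.0) / x


x = 1e-15
print(exprel_naive(x))  # noticeably different from 1
print(exprel(x))        # 1.0 to within rounding
print(exprel(0.0))      # 1.0, whereas exprel_naive(0.0) divides by zero
```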
Thank you for your reply! It seems I need to update brainpy to use this function; I'll try it.
I updated brainpy and JAX and replaced `bm.exp` with `bm.exprel` as you did. It did work at a current strength of 0.280, but I am not sure it still works at other current strengths: judging from the resulting firing rates, there are now problems at values that never caused trouble before I replaced the function.
I set the current strength to zero, and was surprised to find that the HH neuron still produced spikes.
When I switched back to the original `bm.exp`, it behaved just as before, working well at every current strength except 0.280. This rules out the package updates as the cause of the new problems.
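A hedged sketch of how the two rates could be rewritten, assuming `bm.exprel(x) = (exp(x) - 1) / x`. The `alpha` expressions cannot simply have `exp` swapped for `exprel`: with x = -(V + 40)/10 the `dm` rate 0.1(V + 40)/(1 - e^x) simplifies to 1/exprel(x), and with x = -(V + 55)/10 the `dn` rate becomes 0.1/exprel(x). A literal textual swap, `1 - exprel(...)`, computes a different function, which is one possible (unconfirmed) explanation for the spikes at zero input. This uses plain Python with a stand-in `exprel`; all function names are hypothetical.

```python
import math


def exprel(x):
    """Stand-in for (exp(x) - 1) / x; equals 1 at x == 0."""
    return 1.0 if x == 0.0 else math.expm1(x) / x


def alpha_m_original(V):
    return 0.1 * (V + 40) / (1 - math.exp(-(V + 40) / 10))


def alpha_m_exprel(V):
    # 0.1*(V+40)/(1-exp(x)) with x = -(V+40)/10 simplifies to 1/exprel(x)
    return 1.0 / exprel(-(V + 40) / 10)


def alpha_n_exprel(V):
    # 0.01*(V+55)/(1-exp(x)) with x = -(V+55)/10 simplifies to 0.1/exprel(x)
    return 0.1 / exprel(-(V + 55) / 10)


def alpha_m_literal_swap(V):
    # exp replaced textually by exprel: NOT the same function
    return 0.1 * (V + 40) / (1 - exprel(-(V + 40) / 10))


print(alpha_m_original(-30.0), alpha_m_exprel(-30.0))  # equal away from -40 mV
print(alpha_m_exprel(-40.0))        # 1.0, the correct limit (the original gives 0/0)
print(alpha_m_literal_swap(-30.0))  # differs from the original rate
```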
Hello, I have a question about the function `pre2post_event_sum`. The documentation says: "The pre-to-post event-driven synaptic summation with CSR synapse structure. When `values` is a scalar, this function is equivalent to"

```python
post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
    if events[i]:
        for j in range(idnptr[i], idnptr[i+1]):
            post_val[post_ids[i]] += values
```

I wonder whether `post_val[post_ids[i]] += values` is a typo and should be `post_val[post_ids[j]] += values`.
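For reference, a plain-Python sketch of the loop with the index proposed in the question (`post_ids[j]`). `pre2post_event_sum_ref` is a hypothetical name, not brainpy's implementation; it assumes `pre2post = (post_ids, indptr)` in CSR layout, where the targets of presynaptic neuron `i` are `post_ids[indptr[i]:indptr[i+1]]`.

```python
def pre2post_event_sum_ref(events, pre2post, post_num, values):
    """Event-driven sum over a CSR connection with a scalar `values`."""
    post_ids, indptr = pre2post
    post_val = [0.0] * post_num
    for i, fired in enumerate(events):
        if fired:
            # j walks over the synapses owned by presynaptic neuron i
            for j in range(indptr[i], indptr[i + 1]):
                post_val[post_ids[j]] += values  # index by j, not i
    return post_val


# pre 0 -> post 1, 2; pre 1 -> post 0; pre 2 -> post 2
post_ids = [1, 2, 0, 2]
indptr = [0, 2, 3, 4]
result = pre2post_event_sum_ref([True, False, True], (post_ids, indptr), 3, 1.0)
print(result)  # [0.0, 1.0, 2.0]
```

With `post_ids[i]` instead, presynaptic neuron 0 would add twice to `post_ids[0]` (post 1) and neuron 2 would add to `post_ids[2]` (post 0), giving `[1.0, 2.0, 0.0]`, which ignores the actual connectivity.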