jaxleyverse / jaxley

Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU.
https://jaxley.readthedocs.io
Apache License 2.0

Solution of HH depends on dt #149

Open coschroeder opened 1 year ago

coschroeder commented 1 year ago

I experimented with the dt of the stimulus and noticed that the number of spikes in the standard HH model depends on the step size. Code to reproduce:

import jax.numpy as jnp
import matplotlib.pyplot as plt
import neurax as nx
from neurax.channels import HHChannel

# Number of segments per branch.
nseg_per_branch = 1
comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg_per_branch)])

# point neuron:
cell = nx.Cell([branch for _ in range(1)], parents=jnp.asarray([-1]))

cell.set_params("length", 96) 
cell.set_params("radius", 96) 

cell.insert(HHChannel())

# Stimulus.
i_delay = 20.0  # ms
i_amp = 3.7  # nA 0.08
i_dur = 20.0  # ms

# Duration and step size.
dt = 0.0025  # ms; changing dt here changes the HH solution
t_max = 60.0  # ms

time_vec = jnp.arange(0.0, t_max+dt, dt)

stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [nx.Recording(0, 0, 0.0)]

# Solve HH
s = nx.integrate(cell, stims, recs)

plt.figure(1,  figsize=(6, 4))
ax = plt.subplot(211)
_ = ax.plot(time_vec, s.T[:-1])
_ = ax.set_ylim([-90, 130])
_ = ax.set_ylabel("Voltage (mV)")

ax = plt.subplot(212)
_ = ax.plot(time_vec, stims[0].current)
_ = ax.set_xlabel("Time (ms)")
_ = ax.set_ylabel("Stimulus (nA)")

Am I doing something wrong here? Or is there a bug in the solver?

michaeldeistler commented 1 year ago

Passing dt to the solver through the delta_t argument fixes the issue:

s = nx.integrate(cell, stims, recs, delta_t=dt)
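
One way to see why the mismatch matters is the pure-Python sketch below. It does not use neurax. The helper names (`step_current_samples`, `stimulus_window_seen_by_solver`) are hypothetical, and the solver step of 0.025 ms is an assumed value for illustration, not a confirmed default. The sketch also assumes that the solver consumes one stimulus sample per integration step, so the time base of the stimulus array is set entirely by the step size the solver uses.

```python
def step_current_samples(i_delay, i_dur, i_amp, dt, t_max):
    """Hypothetical stand-in for a step-current helper.

    Returns one current sample per time step of size dt, covering [0, t_max].
    Integer sample indices are used to avoid floating-point edge effects.
    """
    n_samples = int(round(t_max / dt)) + 1
    first_on = int(round(i_delay / dt))
    first_off = int(round((i_delay + i_dur) / dt))
    return [i_amp if first_on <= k < first_off else 0.0 for k in range(n_samples)]


def stimulus_window_seen_by_solver(current, solver_dt):
    """Onset and duration (ms) of the nonzero stimulus, assuming the solver
    advances by solver_dt for every sample in the current array."""
    on_indices = [k for k, value in enumerate(current) if value != 0.0]
    onset = on_indices[0] * solver_dt
    duration = len(on_indices) * solver_dt
    return onset, duration


stim_dt = 0.0025           # ms, step size used to build the stimulus (from the issue)
assumed_solver_dt = 0.025  # ms, ASSUMED solver step when delta_t is not passed

current = step_current_samples(20.0, 20.0, 3.7, stim_dt, 60.0)

# Solver step matches the stimulus step: the pulse is where it was intended.
seen_correct = stimulus_window_seen_by_solver(current, stim_dt)

# Solver step differs: the same array is stretched in time by a factor of 10.
seen_mismatch = stimulus_window_seen_by_solver(current, assumed_solver_dt)

print("matching dt   (onset, duration) in ms:", seen_correct)
print("mismatched dt (onset, duration) in ms:", seen_mismatch)
```

Under these assumptions, the same array describes a much longer pulse when the solver step is larger than the stimulus step, which would change the spike count. Passing delta_t=dt keeps the two time bases aligned.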