jaxleyverse / jaxley

Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU.
https://jaxley.readthedocs.io
Apache License 2.0

Clamping of Synapse: Following Synapse might also be clamped #485

Open deezer257 opened 9 hours ago

deezer257 commented 9 hours ago

If I create a network in which cells are connected via RibbonSynapses and apply a data clamp to one of the RibbonSynapses, not only is that synapse clamped to the defined value, but so is the next synapse in sequence (the one with the incremented index). Furthermore, it doesn't matter which method I use to index the synapses (pre- or post-synaptic cell indexing, or indexing via the edges). An example is attached:

JAX version: 0.4.35

```python
import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley_mech.synapses.ribbon import RibbonSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)

# Insert the Leak, Na and K channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())

v_rest = -70

# Set the initial membrane potential and the leak reversal potential
# to the resting potential, so that the cells start at rest
net_opt.set('v', v_rest)
net_opt.set('Leak_eLeak', v_rest)

# Select the four cells of the chain
first = net_opt.cell(0)
second = net_opt.cell(1)
third = net_opt.cell(2)
fourth = net_opt.cell(3)

# Connect the cells in a chain with ribbon synapses
connect(first, second, RibbonSynapse(solver = "newton"))
connect(second, third, RibbonSynapse(solver = "newton"))
connect(third, fourth, RibbonSynapse(solver = "newton"))

net_opt.cell([0,1]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell([1,2]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell([2,3]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell('all').set('RibbonSynapse_gS', 0)

# Parameters
time_max = 1000
dt = 1
time_steps = int((time_max + dt) / dt)

# Time vector
time_vec = jnp.arange(0.0, time_max, dt)

# Generate a base signal with a Gaussian function
mean = (time_max / 2) - 100  # Center of the Gaussian
std_dev = 100  # Standard deviation of the Gaussian
base_signal = (np.exp(-0.5 * ((time_vec - mean) / std_dev) ** 2)) * 3 + 2

# Stack the time vector (first column) and the base signal (second column)
# into a 2D array of inputs
inputs = jnp.array([time_vec, base_signal]).T

# Integrate the network without using vmap and jit
net_opt.delete_recordings()
net_opt.delete_stimuli()

#net_opt.cell([0,1]).record("RibbonSynapse_exo")
#net_opt.cell([1,2]).record("RibbonSynapse_exo")
#net_opt.cell([2,3]).record("RibbonSynapse_exo")

net_opt.RibbonSynapse.edge(0).record("RibbonSynapse_exo", verbose = False)
net_opt.RibbonSynapse.edge(1).record("RibbonSynapse_exo", verbose = False)
net_opt.RibbonSynapse.edge(2).record("RibbonSynapse_exo", verbose = False)

net_opt.cell(0).record("v")
net_opt.cell(1).record("v")
net_opt.cell(2).record("v")
net_opt.cell(3).record("v")

data_clamps = None
# The clamp values are the y-values (second column) of the inputs
#data_clamps = net_opt.cell([0,1]).data_clamp("RibbonSynapse_exo", inputs[:,1], data_clamps = data_clamps)
data_clamps = net_opt.edge(0).data_clamp("RibbonSynapse_exo", inputs[:,1], data_clamps = data_clamps)

# Integrate the network
s = jx.integrate(net_opt, 
                data_clamps = data_clamps,
                solver = "bwd_euler")

fig, ax = plt.subplots(s.shape[0], 1, figsize=(10, 20))
# Increase space between subplots
plt.subplots_adjust(hspace=1)

# Loop over the subplots and plot each recorded trace
for i in range(s.shape[0]):
    ax[i].plot(s[i, :])
    if i < 3:
        ax[i].set_title(f"Synapse {i + 1}")
    else:
        ax[i].set_title(f"Membrane Voltage Cell {i - 3}")
```

[Attached image]
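To check the symptom numerically rather than by eye, a small helper can report which recorded rows reproduce the clamp signal. This is a hypothetical helper (`rows_following_clamp` is not part of Jaxley) written with NumPy only. It assumes the rows of `s` are the recordings in the order they were added, and that each row may begin with one initial-state sample before the first clamped step, which is what the default `skip=1` accounts for. That alignment is an assumption; if no row matches, try a different `skip`.

```python
import numpy as np


def rows_following_clamp(traces, clamp, skip=1, atol=1e-6):
    """Return the indices of recorded rows that reproduce a clamp signal.

    Hypothetical debugging helper (not part of Jaxley).

    traces: array of shape (num_recordings, num_samples), e.g. the `s`
        returned by `jx.integrate`.
    clamp: 1D array that was passed to `data_clamp`.
    skip: leading samples to ignore; assumed here to be one initial-state
        sample recorded before the first clamped step.
    """
    traces = np.asarray(traces, dtype=float)
    clamp = np.asarray(clamp, dtype=float)
    # Compare only the samples that both arrays cover.
    n = min(traces.shape[1] - skip, clamp.shape[0])
    if n <= 0:
        return []
    window = traces[:, skip:skip + n]
    # A row "follows" the clamp if every compared sample matches it.
    matches = np.all(np.isclose(window, clamp[:n], atol=atol), axis=1)
    return np.flatnonzero(matches).tolist()


if __name__ == "__main__":
    # Synthetic data only: three "synapse" rows with 1001 samples each.
    rng = np.random.default_rng(0)
    clamp = np.ones(1000)
    traces = rng.normal(size=(3, 1001))
    traces[0, 1:] = clamp  # the synapse that was meant to be clamped
    traces[1, 1:] = clamp  # a neighbour that also follows the clamp
    print(rows_following_clamp(traces, clamp))  # [0, 1]
```

Applied to the first example, `rows_following_clamp(s[:3], inputs[:, 1])` looks only at the three synapse rows. If the clamp acted on a single synapse, one index would be expected; the behaviour reported here would show up as two consecutive indices.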

jnsbck commented 2 hours ago

Hey, thanks for reporting this. Three things that would be really helpful to clarify first:

  1. are you on the most recent commit?
  2. Did you check whether this also happens with other synapses, e.g. IonotropicSynapse? If it does not, I suspect the issue might be in jaxley_mech.
  3. If 1. and 2. do not resolve it, could you come up with a more minimal example, i.e. remove everything that is not strictly necessary but still reproduces this behavior? That would be really helpful for debugging.
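For question 1, the installed versions of the relevant packages can be collected with the standard library alone. This is a sketch: the helper name `installed_versions` is hypothetical, and the distribution names in the list (notably `jaxley-mech`) are assumptions, so match them to what `pip list` shows. For an editable install from a git clone, the version string does not identify the commit; running `git rev-parse HEAD` inside the clone does.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Distribution names are assumptions; compare against `pip list`.
    names = ["jax", "jaxlib", "jaxley", "jaxley-mech"]
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```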
deezer257 commented 1 hour ago
  1. I cloned the main branch of the repository with git and installed it with pip in editable mode on 07.11.2024 at 16:00, so yes. The problem with the RibbonSynapse persists.

  2. I tried to clamp a parameter of the IonotropicSynapse (`IonotropicSynapse_gS`), but `data_clamp` does not accept it (the error is shown after the MWE below).

  3. A shorter MWE for IonotropicSynapse:

    
```python
import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)

# Insert the Leak, Na and K channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())

# Connect the cells with the synapses
connect(net_opt.cell(0), net_opt.cell(1), IonotropicSynapse())
connect(net_opt.cell(1), net_opt.cell(2), IonotropicSynapse())
connect(net_opt.cell(2), net_opt.cell(3), IonotropicSynapse())

# Set the conductance of the synapses to zero
net_opt.cell('all').set('IonotropicSynapse_gS', 0)

inputs = jnp.ones(1000)

net_opt.delete_recordings()
net_opt.delete_stimuli()

# Record the conductance of the synapses
net_opt.cell([0,1]).record("IonotropicSynapse_gS")
net_opt.cell([1,2]).record("IonotropicSynapse_gS")
net_opt.cell([2,3]).record("IonotropicSynapse_gS")

# Clamp the conductance of the selected synapse
data_clamps = None
data_clamps = net_opt.cell([0,1]).data_clamp("IonotropicSynapse_gS", inputs, data_clamps = data_clamps)

# Integrate the network
s = jx.integrate(net_opt, data_clamps = data_clamps, solver = "bwd_euler")

fig, ax = plt.subplots(2, 1, figsize=(5, 5))
ax[0].plot(s[0, :])
ax[1].plot(s[1, :])
```


This always raises the following error, which doesn't make much sense to me, since I was able to set `IonotropicSynapse_gS` to 0 earlier:
![image](https://github.com/user-attachments/assets/0047791c-5170-4536-baca-c986428c0653)

I did exactly the same with the RibbonSynapse, and there `data_clamp` worked.

4. A shorter MWE for the RibbonSynapse:
```python
import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley_mech.synapses.ribbon import RibbonSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)

# Insert the Leak, Na and K channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())

# Connect the cells with the ribbon synapses
connect(net_opt.cell(0), net_opt.cell(1), RibbonSynapse(solver = "newton"))
connect(net_opt.cell(1), net_opt.cell(2), RibbonSynapse(solver = "newton"))
connect(net_opt.cell(2), net_opt.cell(3), RibbonSynapse(solver = "newton"))

# Set the conductance of the synapse to zero
net_opt.cell('all').set('RibbonSynapse_gS', 0)

inputs = jnp.ones(1000)

net_opt.delete_recordings()
net_opt.delete_stimuli()

# Record the RibbonSynapse_exo state of the synapses
net_opt.cell([0,1]).record("RibbonSynapse_exo")
net_opt.cell([1,2]).record("RibbonSynapse_exo")
net_opt.cell([2,3]).record("RibbonSynapse_exo")

# Clamp RibbonSynapse_exo of the selected synapse
data_clamps = None
data_clamps = net_opt.cell([0,1]).data_clamp("RibbonSynapse_exo", inputs, data_clamps = data_clamps)

# Integrate the network
s = jx.integrate(net_opt, 
                data_clamps = data_clamps,
                solver = "bwd_euler")

fig, ax = plt.subplots(2, 1, figsize=(5, 5))
ax[0].plot(s[0, :])
ax[1].plot(s[1, :])
```

The same problem occurs with the data-clamped synapse. The figure shows the recorded synapse traces.

[Attached image]