fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.32k stars 237 forks source link

Spike on positive and negative threshold #506

Closed JackCaster closed 5 months ago

JackCaster commented 6 months ago

Issue type

SpikingJelly version

0.0.0.0.14

Description

I would like to have a LIF neuron that can spike when the potential crosses a positive (+1) or negative (-1) threshold. I think I got the a custom LIF neuron to work: image

but my attempt to regress the membrane potential (by finding the k gain that is applied to the input current) fails image

and I suspect that the surrogate functions may not work as intended when the spike comes from a negative threshold. The training loss would keep oscillating without converging despite the learning rate (I do not have this problems with normal LIF neurons)

Do you know how I could get this to work?

Minimal code to reproduce the error/bug

#!/usr/bin/env python3
#!/usr/bin/env python3
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron

class MyLIF(neuron.BaseNode):
    def __init__(self, k=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.k = nn.Parameter(torch.tensor(k))

    def neuronal_charge(self, x):
        k = torch.sigmoid(self.k)
        self.v = self.v + x * k

    def neuronal_fire(self):
        vsign = torch.sign(self.v)
        value = vsign * (self.v - (vsign * self.v_threshold))
        return self.surrogate_function(value)

def train_model(model, xs, ys, lr, epochs):
    model.train()

    ps = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(ps, lr=lr)
    criterion = nn.MSELoss(reduction="sum")

    print("Training network...")
    for i in range(epochs):
        model.reset()

        ys_hat = []

        for x in xs:
            model(x)
            ys_hat.append(model.v.unsqueeze(0))

        ys_hat = torch.concatenate(ys_hat)

        optimizer.zero_grad()
        loss = criterion(ys_hat, ys)
        loss.backward()

        optimizer.step()

        print(f"Epoch {i: 5d}", f" | Loss {loss.item():0.4f}")

    model.reset()
    return model.eval()

if __name__ == "__main__":

    import matplotlib.pyplot as plt
    import numpy as np

    torch.manual_seed(2024)
    np.random.seed(2024)

    dt = 0.01
    ts = torch.linspace(0.0, 1.0, 100)
    xs = torch.cat((torch.ones(50), -torch.ones(50))) * 0.1

    with torch.no_grad():
        model = MyLIF(k=1.5)
        model.reset()

        ys = []
        spks = []
        for x in xs:
            spk = model(x)
            ys.append(model.v.unsqueeze(0))
            spks.append(spk.unsqueeze(0))

        ys = torch.cat(ys)
        spks = torch.cat(spks)

    fig, ax = plt.subplots(2, 1, tight_layout=True)
    ax[0].plot(ts, xs)
    ax[0].set_title("Stimuli")

    ax[1].plot(ts, ys, label="True")
    ax[1].vlines(ts[spks.squeeze().bool()], -1.1, 1.1, ls=":", color="r")
    ax[1].set_title("Membrane potential")

    plt.show()

    model_hat = train_model(MyLIF(), xs=xs, ys=ys, lr=1e-2, epochs=100)
    model_hat.reset()

    with torch.no_grad():
        ys_hat = []
        spks_hat = []

        for x in xs:
            spk_hat = model_hat(x)

            ys_hat.append(model_hat.v.unsqueeze(0))
            spks_hat.append(spk_hat.unsqueeze(0))

        ys_hat = torch.concatenate(ys_hat)

    fig, ax = plt.subplots(2, 1, tight_layout=True)
    ax[0].plot(ts, xs)
    ax[0].set_title("Stimuli")

    ax[1].plot(ts, ys, label="True")
    ax[1].plot(ts, ys_hat, label="Estimate")
    ax[1].set_title("Membrane potential")

    plt.legend()
    plt.show()
fangwei123456 commented 6 months ago

You can try to define the neuron as this:

import torch
from spikingjelly.activation_based import neuron

from matplotlib import pyplot as plt

class ABSThresholdLIFNode(neuron.SimpleLIFNode):
    def neuronal_fire(self):
        return self.surrogate_function(torch.abs(self.v) - self.v_threshold)

T = 64
x = torch.cat((0.4 * torch.ones([T//2]), -0.4 * torch.ones([T//2])))

net = ABSThresholdLIFNode(tau=100., decay_input=False)

v = []
s_t = []
for t in range(T):
    s_t.append(net(x[t]) * t)
    v.append(net.v)

fig = plt.figure()
plt.subplot(2, 1, 1)
plt.plot(torch.arange(T), x, label='input')
plt.plot(torch.arange(T), v, label='v')
plt.legend()

plt.subplot(2, 1, 2)
plt.eventplot(s_t, label='spike', colors='red')
plt.legend()

plt.show()

image

JackCaster commented 5 months ago

Thanks!