fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Propagation Pattern #236

Closed asudeeaydin closed 2 years ago

asudeeaydin commented 2 years ago

Hi,

I am slightly confused about the section on 'Propagation Pattern' and want to bring up several points.

In the step-by-step mode, for a non-zero input, you have non-zero states in the neurons at the last layer. In other words, the spike propagation of the input isn't implemented in discrete time steps.

In the layer-by-layer mode, the implementation is similar to neuromorphic chips or most of the other SNN frameworks where spike propagation is in discrete time steps. Meaning, that spikes that are caused by a non-zero input need to propagate through the network in order to have non-zero states in the next layers. More concretely, if we were to have a 10-layer architecture, and only present input to the network at t=0, we would start seeing neuronal activity at the 10th layer only at t=10.

Single-Step (N, C, W, H) input needs to be used for the step-by-step mode whereas the Multi-Step (T, N, C, W, H) input needs to be used for the layer-by-layer mode.

I hope my understanding is so far correct.

[Image: layerbylayer_comp_graph]

Thanks a lot for this great work, and I really appreciate the quick and helpful responses!

fangwei123456 commented 2 years ago

Hi, you can refer to https://spikingjelly.readthedocs.io/zh_CN/fuse_ms/activation_based_en/basic_concept.html for more details. It is the tutorial for the next version of SpikingJelly, and I think it answers your questions:

There are two dimensions in the computation graph of an SNN: the time-step dimension and the depth dimension. As the above figures show, propagating an SNN amounts to building this computation graph. The step-by-step propagation pattern traverses the graph as a Depth-First Search (DFS), while the layer-by-layer propagation pattern traverses it as a Breadth-First Search (BFS).
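The two traversal orders can be sketched in plain PyTorch. This is only an illustration under stated assumptions: `ToyIF`, `step_by_step` and `layer_by_layer` are hypothetical names written for this example, not SpikingJelly classes or functions. Because each layer still sees its own inputs in the same temporal order, both loops should produce the same output sequence.

```python
import torch
import torch.nn as nn


class ToyIF(nn.Module):
    """Hypothetical stateful integrate-and-fire neuron (hard reset to 0)."""

    def __init__(self, v_threshold: float = 1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.v = None  # membrane potential, created lazily on first call

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.v + x                               # integrate the input
        spike = (self.v >= self.v_threshold).to(x.dtype)
        self.v = self.v * (1.0 - spike)                   # reset where a spike fired
        return spike

    def reset(self):
        self.v = None


def reset_net(layers):
    # Clear the state of every stateful layer before reusing the network.
    for m in layers:
        if hasattr(m, "reset"):
            m.reset()


def step_by_step(layers, x_seq):
    # Outer loop over time, inner loop over depth (depth-first traversal).
    outs = []
    for x in x_seq:
        for m in layers:
            x = m(x)
        outs.append(x)
    return torch.stack(outs)


def layer_by_layer(layers, x_seq):
    # Outer loop over depth, inner loop over time (breadth-first traversal).
    seq = x_seq
    for m in layers:
        seq = torch.stack([m(x) for x in seq])
    return seq


torch.manual_seed(0)
T, N = 4, 2
layers = [nn.Linear(8, 6), ToyIF(), nn.Linear(6, 3), ToyIF()]
x_seq = torch.rand(T, N, 8)  # multi-step input of shape (T, N, features)

with torch.no_grad():
    y_step = step_by_step(layers, x_seq)
    reset_net(layers)
    y_layer = layer_by_layer(layers, x_seq)

# Compare the two output sequences, each of shape (T, N, 3).
print(torch.allclose(y_step, y_layer))
```

Neither loop inserts a delay between layers: the input at time-step t reaches the last layer within the same time-step in both modes.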

fangwei123456 commented 2 years ago

The attached drawing is what I would've expected.

I think you mean "latency" in the actual physical system. Suppose we run an SNN on a chip. Denote M[i] as layer i of the L layers M[0], ..., M[L-1], and x[j] as the j-th element of the input sequence x[0], x[1], ..., x[T-1], and suppose each layer needs one second to process its input. The actual forward pass in step-by-step mode is then:

t = 0s, input x[0] to M[0]
t = 1s, M[0] outputs to M[1]
t = 2s, M[1] outputs to M[2]
...
t = Ls, M[L-1] outputs y[0]

The latency does exist, but it has no influence on the outputs, because we still have y[0] = M[L-1](M[L-2](...M[0](x[0]))). When simulating the SNN, the time-step counts how many times we run the network, not the real time (latency) that an SNN needs in a physical system. For example, if we use an SNN with T=4 for classification, then

 y[0] = M[L-1](M[L-2](...M[0](x[0])))
 y[1] = M[L-1](M[L-2](...M[0](x[1])))
 y[2] = M[L-1](M[L-2](...M[0](x[2])))
 y[3] = M[L-1](M[L-2](...M[0](x[3])))

where y[t] has shape [C] (ignore the batch dimension), and we use argmax(y[0]+y[1]+y[2]+y[3]) as the classification result.
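The readout described above can be sketched as follows. The network here is a stateless stand-in chosen only to keep the example short (in a real SNN the layers carry membrane state from one time-step to the next), and the sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, C = 4, 10  # number of time-steps and number of classes (assumed values)

# Stand-in for the stacked layers; a real SNN would use spiking neurons here.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, C))
x_seq = torch.rand(T, 8)

with torch.no_grad():
    # One forward pass per time-step; y_seq has shape [T, C].
    y_seq = torch.stack([net(x_seq[t]) for t in range(T)])

# Accumulate the outputs over all time-steps, then pick the largest class score.
pred = y_seq.sum(dim=0).argmax().item()
print(pred)
```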

asudeeaydin commented 2 years ago

Thanks a lot for your answer.

I do understand why the output of a layer-by-layer and step-by-step case would be the same. However, I'm still a bit confused about time steps and latency.

If latency still exists, then I would expect that it has an influence on the output.

For a 3-layer architecture (input, hidden & output layer), I would expect the following:

Given inputs x[t=0] = X, x[t=1] = 0, x[t=2] = 0,

y[t=0] = 0
y[t=1] = 0
y[t=2] != 0

But when I print the output layer activities for such a case, I get:

y[t=0] != 0
y[t=1] != 0
y[t=2] != 0

I'm using SpikingJelly for a regression task, and since I have to supply a different input at every time step (e.g. a flow of events), I don't see how latency and time steps are different.

So basically, I expect spikes to propagate through the network over time, as in a biological system, advancing one layer per discrete time step. If this is not the case, then deploying the weights of a network trained with SpikingJelly on a neuromorphic chip (even an ideal one) wouldn't yield the same result?

fangwei123456 commented 2 years ago

I think you want to simulate latency in SpikingJelly. You can use the following code:

import torch
import torch.nn as nn

class LatencyReadout(nn.Module):
    # Buffers its inputs and returns them later. As written, the first
    # (latency - 1) calls return zeros, so the delay is latency - 1 steps.
    def __init__(self, latency: int):
        super().__init__()
        self.latency = latency
        self.buffer = []

    def forward(self, x: torch.Tensor):
        self.buffer.append(x)
        if len(self.buffer) < self.latency:
            # Buffer not full yet: nothing has reached the readout.
            return torch.zeros_like(x)
        else:
            # Emit the oldest buffered value.
            return self.buffer.pop(0)

    def reset(self):
        self.buffer.clear()

T = 8
latency = 3

x_seq = torch.rand([T, 8])
net = nn.Sequential(
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
    LatencyReadout(latency)
)

with torch.no_grad():
    # Run extra steps with zero input so the buffered outputs are flushed.
    for t in range(T + latency):
        if t < T:
            print(t, net(x_seq[t]))
        else:
            print(t, net(torch.zeros_like(x_seq[0])))

Outputs are:

0 tensor([0.])
1 tensor([0.])
2 tensor([-0.0318])
3 tensor([-0.0772])
4 tensor([-0.0869])
5 tensor([0.0078])
6 tensor([0.0913])
7 tensor([-0.0771])
8 tensor([0.0370])
9 tensor([-0.0803])
10 tensor([-0.0164])

fangwei123456 commented 2 years ago

I'm using SpikingJelly for a regression task, and since I have to supply a different input at every time step (e.g. a flow of events), I don't see how latency and time steps are different.

Latency also exists on a CPU or GPU, but it is different from the time-step. Here is an example:

import torch
import torch.nn as nn
import time
T = 8

x_seq = torch.rand([T, 8])
net = nn.Sequential(
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 1)
)

t0 = time.perf_counter()

with torch.no_grad():
    for t in range(T):
        t_input = time.perf_counter() - t0
        y = net(x_seq[t])
        t_output = time.perf_counter() - t0
        latency = t_output - t_input
        print(f'in time={t_input}, out time={t_output}, latency={latency}, time-step={t}, y[{t}]={y}')

Outputs are:

in time=3.5000000000451337e-06, out time=0.00023819999999996622, latency=0.00023469999999992108, time-step=0, y[0]=tensor([0.4120])
in time=0.0004418999999999951, out time=0.0005028999999999728, latency=6.099999999997774e-05, time-step=1, y[1]=tensor([0.4448])
in time=0.000669099999999978, out time=0.0007342000000000182, latency=6.510000000004013e-05, time-step=2, y[2]=tensor([0.5299])
in time=0.0009012000000000464, out time=0.0009556000000000564, latency=5.4400000000009996e-05, time-step=3, y[3]=tensor([0.5451])
in time=0.0011248000000000369, out time=0.0011790999999999885, latency=5.429999999995161e-05, time-step=4, y[4]=tensor([0.4139])
in time=0.0013461000000000167, out time=0.0014003999999999683, latency=5.429999999995161e-05, time-step=5, y[5]=tensor([0.4308])
in time=0.0015648999999999802, out time=0.0016199000000000074, latency=5.500000000002725e-05, time-step=6, y[6]=tensor([0.4535])
in time=0.001776300000000064, out time=0.0018297000000000452, latency=5.339999999998124e-05, time-step=7, y[7]=tensor([0.3866])

You can see that at time-step 0 the network needs latency=0.00023469999999992108 seconds to compute its output. If you could read the net's output before that latency had elapsed, you would get nothing (None), because the output has not been computed yet.

asudeeaydin commented 2 years ago

I see, thank you! Really helpful!

asudeeaydin commented 2 years ago

Sorry for re-opening this issue, I want to be absolutely sure that I am understanding the correct thing.

So, you're saying that a single forward pass in SpikingJelly is equivalent to simulating a real-time neuromorphic chip (like Loihi) for T+1 timesteps, where T is the depth of the network.

If my network is 3 layers, in Loihi I would get:

(Timestep, input, output)
t=0 x=X y=0
t=1 x=0 y=0
t=2 x=0 y=0
t=3 x=0 y=Y

Instead, a single forward pass in SpikingJelly gives me:

t=0 x=X y=Y

asudeeaydin commented 2 years ago

This would mean that, in a single forward pass, neurons in SpikingJelly should decay by the equivalent of T Loihi timesteps.

What I mean is that the neuron potentials of Loihi at timestep t=3 should be equal to those of the SpikingJelly neurons at timestep t=0, the neuron potentials of Loihi at timestep t=6 should be equal to those of the SpikingJelly neurons at timestep t=1, and so on.

But in the code you provided this is not the case? Once the output becomes non-zero, the membrane potentials in the 'latency' case follow the same trajectory as in the 'non-latency' case (i.e. a normal SNN implementation in SpikingJelly).

asudeeaydin commented 2 years ago

I'm assuming that the neurons are Leaky Integrate & Fire neurons.

Simply put, for a network of 3 layers, do spiking jelly neurons decay for 3 timesteps or 1 timestep for a single forward pass?

fangwei123456 commented 2 years ago

do spiking jelly neurons decay for 3 timesteps or 1 timestep for a single forward pass?

Neurons in SJ decay for 1 time-step per forward pass. If you want the neurons to decay for 3 time-steps, you can add a LatencyReadout with latency=1 behind each spiking neuron:

class LatencyReadout(nn.Module):
    def __init__(self, latency: int):
        super().__init__()
        self.latency = latency
        self.buffer = []

    def forward(self, x: torch.Tensor):
        self.buffer.append(x)
        # Use <= so that latency=1 delays the input by one step. With <, as in
        # the earlier snippet, latency=1 would return x immediately.
        if len(self.buffer) <= self.latency:
            return torch.zeros_like(x)
        else:
            return self.buffer.pop(0)

    def reset(self):
        self.buffer.clear()

I think this latency layer works similar to that in Loihi (https://github.com/lava-nc/lava-dl/blob/4bc9c6ea950e9329138a775a477693c6d23d0a38/src/lava/lib/dl/slayer/axon/delay.py#L12).

But I think decaying for 3 time-steps and decaying for 1 time-step should give the same outputs: during the first 2 time-steps the LIF neuron does not receive any input, and decay applied to v=0 leaves v at 0.
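This claim can be checked with a small sketch. Everything here is an assumption made for illustration: `ToyLIF` and `OneStepDelay` are hypothetical classes (not SpikingJelly or Lava modules), the neuron leaks toward 0 and resets to 0, and the linear layers have no bias so that zero input produces zero current. Under those assumptions, putting a one-step delay behind each neuron should only shift the output in time.

```python
import torch
import torch.nn as nn


class ToyLIF(nn.Module):
    """Hypothetical LIF neuron: leaks toward 0, hard reset to 0."""

    def __init__(self, tau: float = 2.0, v_threshold: float = 0.2):
        super().__init__()
        self.tau, self.v_threshold = tau, v_threshold
        self.v = None  # membrane potential, created lazily on first call

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.v + (x - self.v) / self.tau         # leak and integrate
        spike = (self.v >= self.v_threshold).to(x.dtype)
        self.v = self.v * (1.0 - spike)                   # hard reset to 0
        return spike


class OneStepDelay(nn.Module):
    """Hypothetical delay: returns the previous input, zeros at the first step."""

    def __init__(self):
        super().__init__()
        self.prev = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x) if self.prev is None else self.prev
        self.prev = x
        return out


torch.manual_seed(0)
T = 6
fc1 = nn.Linear(8, 16, bias=False)  # no bias: zero input gives zero current
fc2 = nn.Linear(16, 4, bias=False)

# Both networks share the same weights; only the delayed one pipelines spikes.
plain = nn.Sequential(fc1, ToyLIF(), fc2, ToyLIF())
delayed = nn.Sequential(fc1, ToyLIF(), OneStepDelay(), fc2, ToyLIF(), OneStepDelay())

x_seq = torch.rand(T, 8)
x_pad = torch.cat([x_seq, torch.zeros(2, 8)])  # two extra steps to flush the delays

with torch.no_grad():
    y_plain = torch.stack([plain(x) for x in x_seq])
    y_delayed = torch.stack([delayed(x) for x in x_pad])

# Check that the delayed output equals the plain output shifted by two steps.
print(torch.allclose(y_delayed[2:], y_plain))
# Check that nothing reaches the output during the first two steps.
print(y_delayed[:2].abs().sum().item())
```

If the neuron rested at a non-zero potential, or the layers had a bias, the extra idle steps would change the state before the first spikes arrive, and the two networks could differ.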

fangwei123456 commented 2 years ago

@asudeeaydin I have added a delay layer to SJ, which may be useful for you:

https://spikingjelly.readthedocs.io/zh_CN/latest/sub_module/spikingjelly.activation_based.layer.html#delay-init-en