fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Question about the evolution of the membrane potential #536

Open slucas03 opened 2 months ago

slucas03 commented 2 months ago

Description

Hi @fangwei123456. First of all, congratulations on the great work you have done with SpikingJelly. I have been using it for more than a year and it has enabled great advances in my research.

Currently, I am working on improving my results with SNNs and LIF neurons; however, I have many doubts about how the membrane potential of each neuron evolves when a sample is introduced during training, and I need to understand this evolution. Broadly speaking, the main statements I have used to implement training are the following (taken from one of your SNN examples):

  1. net.train()
  2. optimizer.zero_grad()
  3. SNN_output = net(train_data)
  4. loss = F.mse_loss(SNN_output, target_data)
  5. loss.backward()
  6. optimizer.step()
  7. functional.reset_net(net)

The first statement ("net.train()") puts the network in training mode. The second statement ("optimizer.zero_grad()") sets the gradients of all optimized torch.Tensor to zero. Here is my first question: why is it necessary to set all gradients to zero each time a sample is entered?

The third statement ("SNN_output = net(train_data)") computes the SNN output when an input is introduced. I know that this statement includes the functions of (a) neuronal_charge, (b) neuronal_fire, and (c) neuronal_reset, which are implemented in spikingjelly.clock_driven.neuron and described in https://github.com/fangwei123456/spikingjelly/blob/master/docs/source/activation_based_en/0_neuron.rst. The fourth ("loss = F.mse_loss(SNN_output, target_data)") and fifth ("loss.backward()") statements compute the loss function and backpropagate it, respectively. The sixth statement ("optimizer.step()") updates the weights, i.e. the parameters of the SNN.

Finally, for the seventh statement ("functional.reset_net(net)") you had a comment that reads: "after optimizing the parameters once, the state of the network needs to be reset, because the SNN neurons have 'memory'". This statement raises many doubts in my mind. I mean, this comment only makes sense if you apply stateless neurons, whose membrane potential at each time-step depends only on the input. However, if you apply normal LIF neurons, I strongly believe that this statement should not be used, since you would be resetting their membrane potential at each time-step regardless of whether they have fired a spike or not. I would like to know if I have made any mistake in this reasoning. In addition, I do not understand why, if I remove this statement from the code, the following error appears:

"RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

Finally, I have one last question with reference to the membrane potential estimation performed in the third statement. Let's imagine that we are using LIF neurons described by tau=100, V_reset=0 and V_threshold=1. Moreover, we are effectively using stateless neurons because we have written "functional.reset_net(net)" in the code. If I am not wrong, since V_reset=0, SpikingJelly uses the following equation to estimate the membrane potential: self.v += (x - self.v) / self.tau. In this equation, x - self.v = x, because at the end of the previous time-step we have used "functional.reset_net(net)" and thus the remaining membrane potential at the current time-step is zero, isn't it? My question here is: how does SpikingJelly calculate this x?
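To make the question concrete, this is how I currently picture one time-step of a LIF layer (a hand-written sketch based on my reading of the docs, not SpikingJelly's actual code; the layer, sizes and names are made up):

    import torch
    import torch.nn as nn

    tau, v_threshold = 100.0, 1.0
    fc = nn.Linear(5, 4, bias=False)         # synaptic weights of a hypothetical layer
    v = torch.zeros(4)                       # membrane potentials, V_reset = 0

    spikes_in = torch.tensor([0., 1., 1., 0., 1.])   # presynaptic spikes at this time-step
    x = fc(spikes_in)                        # "x" = weighted input current from the previous layer
    v = v + (x - v) / tau                    # neuronal_charge (decay-input LIF with V_reset = 0)
    spikes_out = (v >= v_threshold).float()  # neuronal_fire (the real code uses a surrogate gradient)
    v = v * (1.0 - spikes_out)               # neuronal_reset (hard reset back to V_reset = 0)

If I understand correctly, the x in the charge equation is simply this weighted input coming from the preceding layer, is that right?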

Thank you in advance.

Ym-Shan commented 1 month ago

From the perspective of single-step mode, we reset the neurons in the entire network after all T time-steps of a given sample have been fed in, rather than performing a reset operation at every time-step.

In addition, regarding the following error you mentioned

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

This is a common error, especially when conducting research on neuronal dynamics. I think the reason for it is that you are using

data2 = data1

Instead of

data2 = data1.clone()

This leads to gradient problems.
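To illustrate what I mean with a toy example of my own (not from your code or from SpikingJelly): assigning without clone() only creates an alias, so an in-place modification also changes a tensor that autograd may still need for backward.

    import torch

    x = torch.ones(3, requires_grad=True)
    y = torch.exp(x)        # autograd saves y to compute the gradient of exp
    alias = y               # like data2 = data1: same underlying tensor
    alias += 1              # the in-place change also modifies y, corrupting the saved value
    y.sum().backward()      # raises a RuntimeError about an in-place modification

    # copy = y.clone()      # like data2 = data1.clone(): an independent tensor,
    # copy += 1             # safe to modify without touching the graph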

slucas03 commented 1 month ago

Thank you for your answer.

But in the case of LIF neurons, or any SNN neuron, it theoretically does not make sense to reset the neurons in the entire network, unless you want to use stateless neurons. In other words, LIF neurons already come with soft- and hard-reset mechanisms that reset their membrane potential when a spike is emitted. Hence, I do not understand why it is necessary to use the statements (2) "optimizer.zero_grad()" and (7) "functional.reset_net(net)".

Could you explain in detail what these statements actually do?

Ym-Shan commented 1 month ago

I think a simple logic can solve your problem: when the second sample is inputted, the residual membrane potential of each neuron is meaningless, which is equivalent to noise for the second sample.

Ym-Shan commented 1 month ago

The function of

functional.reset_net(net)

is to reset the membrane potential of each neuron to zero. The function of

optimizer.zero_grad()

is to reset the gradient to zero.

When performing gradient descent or another optimization algorithm, we update the parameters (such as the weights of the neural network) based on their gradients. PyTorch accumulates gradients during backpropagation, which means that if they are not manually zeroed, gradients will accumulate across multiple backpropagation passes. This accumulation is incorrect here, because each parameter update should be based only on the gradient obtained from the latest computation.
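A small toy example of this accumulation (the numbers are mine, not from your network):

    import torch

    w = torch.tensor(2.0, requires_grad=True)

    loss = (3.0 * w) ** 2   # d(loss)/dw = 18 * w = 36
    loss.backward()
    print(w.grad)           # tensor(36.)

    loss = (3.0 * w) ** 2
    loss.backward()
    print(w.grad)           # tensor(72.) -- accumulated, not replaced

    w.grad.zero_()          # this is what optimizer.zero_grad() does for every parameter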

slucas03 commented 1 month ago

Again, thank you for your quick answer.

I think a simple logic can solve your problem: when the second sample is inputted, the residual membrane potential of each neuron is meaningless, which is equivalent to noise for the second sample.

But why is the residual membrane potential of each neuron meaningless when you introduce the second sample? Broadly speaking, the charge phase of any SNN neuron depends on the membrane potential from the previous time-step and on the spikes that reach the neuron at the current time-step. Having said that, there may be cases in which, at a particular time-step, the membrane potential of a neuron does not exceed the threshold value (e.g. v=0.8) and thus no spike is emitted. But this residual potential is important for the following time-step, because it makes the neuron more likely to emit a spike, since its initial potential at the following time-step will be 0.8. This 0.8 cannot be considered noise.

As I said before, the strategy of resetting the neurons only makes sense if you are applying stateless neurons, but in the vast majority of SNN applications it is necessary to maintain the memory of the neurons. Hence, my final question is: is it possible with SpikingJelly to keep the memory of the neurons and only reset their potential once the last sample has been introduced?

The function of

functional.reset_net(net)

is to reset the membrane potential of each neuron to zero. The function of

optimizer.zero_grad()

is to reset the gradient to zero.

Thank you very much for the explanation. Very clear and concise answer; I like it a lot. Hence, what I understand from your answer is that "functional.reset_net(net)" and "optimizer.zero_grad()" are totally independent of each other and their aims are totally different, aren't they?

Ym-Shan commented 1 month ago

Perhaps my words caused a misunderstanding. Below I will give you an example based on a concrete implementation (still using single-step mode as an example):

for img, label in test_loader:
    for t in range(T):             # T = number of time-steps of this sample
        net(img[t])
    functional.reset_net(net)      # Please pay attention to the position of this line

What I mean by the two successive input samples is two completely different input events (or two encoded images), rather than two different input frames of the same event.

You are right, the goals of "functional.reset_net(net)" and "optimizer.zero_grad()" are completely different: one is for the automatic differentiation engine provided by PyTorch (autograd), and the other is for the SpikingJelly framework.
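Roughly speaking, and paraphrasing from memory rather than quoting the source, functional.reset_net(net) does something like this:

    import torch.nn as nn

    def reset_net_sketch(net: nn.Module):
        # Walk over all sub-modules and reset the ones that keep state
        # (e.g. LIF neuron layers store their membrane potential in self.v).
        for m in net.modules():
            if hasattr(m, 'reset'):
                m.reset()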

slucas03 commented 1 month ago
for img, label in test_loader:
    for t in range(T):             # T = number of time-steps of this sample
        net(img[t])
    functional.reset_net(net)      # Please pay attention to the position of this line

Yes, but here you are basing your answer on an image problem, such as https://github.com/fangwei123456/spikingjelly/blob/master/docs/source/activation_based_en/3_fc_mnist.rst, where each sample is not related to the next or previous one.

Now, imagine that you are dealing with a regression problem, e.g. a time-series problem, in which the samples (each value of the time series) are related in time and, hence, resetting the neurons does not make any sense until the last sample has been introduced and propagated through the network. In addition, you want to update the weights of the network after each sample is propagated. The code scheme you would follow is:

for epoch in range(train_epochs):
    net.train()
    for i in range(len(train_loader) - 1):
        train_sample = train_loader.dataset[i]
        optimizer.zero_grad()
        output = net(train_sample)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
    functional.reset_net(net)    # reset only once, after the last sample of the epoch

As you can see in the code, the resetting of the neurons must be done only after the last sample of the training set has been introduced, in order to take advantage of the potential of the LIF neurons. However, as I said in my first comment, if you apply this code scheme the following error appears: "RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

Nonetheless, this error does not appear if you reset the neurons after the propagation of every sample, as in the following code:

for epoch in range(train_epochs):
    net.train()
    for i in range(len(train_loader) - 1):
        train_sample = train_loader.dataset[i]
        optimizer.zero_grad()
        output = net(train_sample)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        functional.reset_net(net)    # reset after every sample

In regression problems, resetting the LIF neurons with every sample does not make sense, since you are not taking advantage of the potential of LIF neurons and SNNs. Hence, my questions are: why does this happen, and is it possible with SpikingJelly not to reset the neurons with each sample, using code similar to the first scheme I wrote, where the neurons are only reset at the end of each epoch?

NneurotransmitterR commented 1 month ago

Hi! In your code, len(train_loader) returns the number of batches in train_loader, whereas train_loader.dataset[i] returns one sample from your whole dataset, so it would be incorrect to access samples like that. SGD is based on a batch instead of a single sample: in each iteration within one epoch, a batch of samples is fed into the network and gradient descent is performed.
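For example (toy shapes, just to show the difference):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(10, 3), torch.randn(10, 1))
    loader = DataLoader(ds, batch_size=4)

    print(len(loader))           # 3  -> number of batches (4 + 4 + 2 samples)
    print(len(loader.dataset))   # 10 -> number of samples
    sample = loader.dataset[0]   # one (input, target) pair, without a batch dimension
    batch = next(iter(loader))   # one batch: tensors of shape [4, 3] and [4, 1]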

In his code, for img, label in test_loader retrieves a batch, then the network is updated on that batch, and the membrane potentials are reset after each batch, not after each sample. The time dynamics of the SNN are simulated within a batch of samples.

Because the loss function itself is defined on a batch, it would be meaningless to carry the residual potentials into the next batch's gradient descent; that would bring noise into the loss. In a regression problem, there is only one batch, so the membrane potentials are reset after each epoch.

slucas03 commented 1 month ago

Hi! In your code, len(train_loader) returns the number of batches in train_loader, whereas train_loader.dataset[i] returns one sample from your whole dataset, so it would be incorrect to access samples like that. SGD is based on a batch instead of a single sample: in each iteration within one epoch, a batch of samples is fed into the network and gradient descent is performed.

I have to point out that I am using a loss that is not based on rate encoding; it is based on time encoding. Hence, if your batch size is equal to 1 (which is my case), it should be possible to access samples the way I am doing. Moreover, SGD here relies on surrogate gradient functions during backpropagation, so, theoretically, if you are able to compute a loss/error, the batch size you have chosen has no influence on backpropagation. Or are you telling me that SpikingJelly does not support a batch size of 1 during backpropagation?

In a regression problem, there is only one batch, so the membrane potentials are reset after each epoch.

This is what I want to do, but if I remove the neuron-reset statement or change its position, the code does not work, and I would like to know why.

NneurotransmitterR commented 1 month ago

Of course you can set the batch size to 1, and SpikingJelly should support that. Whatever the batch size, backpropagation is performed after the forward propagation of the batch. If the membrane potentials introduced in the forward propagation of the last batch are not reset, they will influence the forward propagation of the current batch; hence the loss of the current batch is no longer calculated only on this batch. The residual information from the last batch is used in the current loss, which should be incorrect. If you are dealing with a regression problem with time encoding, I think you can define your own loss instead of using MSE with batchsize=1? In your code, have you tried debugging to see which line causes the error?

slucas03 commented 2 weeks ago

If the membrane potentials introduced in the forward propagation of the last batch are not reset, they will influence the forward propagation of the current batch.

But this is the hypothesis on which the Leaky Integrate-and-Fire neurons are based, isn't it?

The residual information from the last batch is used in the current loss, which should be incorrect.

The residual membrane potential from the last time instant should influence the current time instant and, thus, it should influence the current loss in some way. I mean, it seems that the LIF neurons implemented in SpikingJelly only behave like LIF neurons within the samples inside the batch.

If you are dealing with a regression problem with time encoding, I think you can define your own loss instead of using MSE with batchsize=1? In your code, have you tried debugging to see which line causes the error?

I have designed and developed a time-based loss function that can be used with batch size 1 and MSE, and its output is not just 0 and 1, so it can be backpropagated. The problem is that if I delete the "functional.reset_net(net)" statement and instead write it at the end of the epoch (resetting the neurons once all the samples of my dataset have been fed into the SNN), the error appears. However, if I keep the statement right after "optimizer.step()", no error appears. I want to do this:

for epoch in range(train_epochs):
    net.train()
    for i in range(len(train_loader) - 1):
        train_sample = train_loader.dataset[i]
        optimizer.zero_grad()
        output = net(train_sample)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
    functional.reset_net(net)    # reset only once, after the last sample of the epoch

Perhaps the solution is to set "retain_graph=True", but that would mean that, with the default parameters of the SpikingJelly examples, the values of the membrane potential are not carried from one instant to the next in the first place. Therefore, if my assumption is correct, I reaffirm that the LIF neurons implemented with SpikingJelly only behave as LIF neurons within the samples contained in each batch.
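For what it is worth, I can reproduce the same error outside SpikingJelly with any stateful tensor that is carried across backward() calls (a toy example of my own):

    import torch

    w = torch.tensor(1.0, requires_grad=True)
    v = torch.tensor(0.0)            # a "membrane potential" carried across iterations

    for x in (torch.tensor(1.0), torch.tensor(2.0)):
        v = v + w * x                # v drags along the graph of previous iterations
        loss = (v - 1.0) ** 2
        loss.backward()              # 2nd iteration: "Trying to backward through the graph
                                     # a second time ..." because v's old graph was freed
                                     # by the previous backward()
        # v = v.detach()             # detaching (or resetting) the state avoids the error

So the carried-over state itself drags the old computation graph along, which seems to be why moving the reset to the end of the epoch triggers the error.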

NneurotransmitterR commented 2 weeks ago

But this is the hypothesis on which the Leaky Integrate-and-Fire neurons are based, isn't it?

Yes, the membrane potential of a LIF neuron is a hidden state evolving through time. However, since we are using BPTT, this hidden state is only meaningful within a batch. Imagine reciting a poem: we read it again and again, our memory of it develops, and finally we can recite it. But when given a different poem, we need to learn it from scratch, and the former memories do not help. That said, I have not carried out experiments on how not resetting the potentials affects network accuracy. For lack of information, I do not know how your algorithm works. Could you please post your complete code here if possible, so that I can check the error?