fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Question: Meaning of the hyperparameter T #186

Closed: slucas03 closed this issue 2 years ago.

slucas03 commented 2 years ago

I am working with the following example: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/clock_driven/examples/lif_fc_mnist.py

In this example there is a hyperparameter called T. I am not sure, but I associate T with the SNN's runtime. The concept of this execution time seems very abstract to me. I have tried to find a clear definition and explanation of its use in papers and forums, but everyone seems to avoid the issue.

In the aforementioned example, this hyperparameter is used in the following lines of code:

for t in range(T):
    if t == 0:
        out_spikes_counter = net(encoder(img).float())
    else:
        out_spikes_counter += net(encoder(img).float())

Here, img is a tensor of shape [batch_size, 1, 28, 28] and out_spikes_counter is a tensor of shape [batch_size, number_output]. I do not quite understand the role of T. What is T used for? Is it related to BackPropagation Through Time (BPTT)? Or is it as if the same input were fed into the SNN (and propagated forward) T times?

I have also noticed that StereoSpike, which is based on SpikingJelly, does not use this hyperparameter (https://github.com/urancon/StereoSpike), so I do not know whether T is an important hyperparameter or not.

Thank you in advance.

fangwei123456 commented 2 years ago

Is it related to BackPropagation Through Time (BPTT)?

Yes. The SNN is trained with BPTT, and T is the total number of time-steps.

Or is it as if the same input were fed into the SNN (and propagated forward) T times?

If you use a sequence as input, then at time-step t the input is x[t]. In this example, the Poisson encoder is used. Note that the Poisson encoder is stochastic and can output different spikes at each time-step even if the input is a constant image.
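
To see why feeding "the same image" T times still gives T different inputs, here is a minimal NumPy sketch of Poisson (rate) encoding. It assumes, as I understand `encoding.PoissonEncoder` to work, that each pixel intensity in [0, 1] is treated as the probability of emitting a spike at each time-step; `poisson_encode` and the 2x2 image are hypothetical, not the SpikingJelly API.

```python
import numpy as np


def poisson_encode(img, rng):
    # Hypothetical helper: each intensity in [0, 1] is a spike probability.
    # A pixel spikes (1.0) where a uniform random draw is <= its intensity.
    return (rng.random(img.shape) <= img).astype(np.float32)


rng = np.random.default_rng(0)
img = np.array([[0.0, 0.25], [0.5, 1.0]])  # a tiny hypothetical "image"

T = 1000
# Encode the SAME image once per time-step: shape [T, 2, 2]
spikes = np.stack([poisson_encode(img, rng) for _ in range(T)])

print(spikes[0])            # one random binary frame
print(spikes.mean(axis=0))  # the spike rate approaches img as T grows
```

Each frame differs, but averaged over many time-steps the spike rate approaches the pixel intensity, so a larger T gives a less noisy view of the image.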

I have also noticed that StereoSpike, which is based on SpikingJelly, does not use this hyperparameter (https://github.com/urancon/StereoSpike), so I do not know whether T is an important hyperparameter or not.

StereoSpike resets the SNN at each time-step, so in that case T = 1.
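
A toy illustration of why resetting the state at every time-step behaves like T = 1. The plain-Python sketch below simulates one integrate-and-fire neuron (charge, fire when v >= threshold, hard reset to 0; a simplified, hypothetical model, not SpikingJelly's `IFNode` code) driven by a constant sub-threshold input, once keeping its membrane potential across steps and once clearing it before every step.

```python
V_TH = 1.0  # hypothetical firing threshold


def run_if(inputs, reset_every_step):
    """Simulate one integrate-and-fire neuron; return its spike train (0.0 / 1.0)."""
    v = 0.0  # membrane potential
    spikes = []
    for x in inputs:
        if reset_every_step:
            v = 0.0  # forget everything from earlier time-steps
        v += x  # charge: integrate the input
        s = 1.0 if v >= V_TH else 0.0  # fire when the threshold is reached
        if s:
            v = 0.0  # hard reset after a spike
        spikes.append(s)
    return spikes


inputs = [0.3] * 8  # constant, sub-threshold input over 8 time-steps
print(run_if(inputs, reset_every_step=False))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
print(run_if(inputs, reset_every_step=True))   # all 0.0: the neuron never fires
```

Without a reset, the potential accumulates until it crosses the threshold; with a reset at every step, nothing is carried between steps, so each step behaves like an independent run with T = 1.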

T is an important hyperparameter: the fitting capacity of an SNN is influenced by T.
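
One concrete way T limits what the network in this example can express: the loss compares out_spikes_counter / T with the one-hot target, and the spike count over T steps is an integer between 0 and T, so each output neuron can only report T + 1 distinct rates. A small sketch of that arithmetic (`achievable_rates` is a hypothetical helper):

```python
from fractions import Fraction


def achievable_rates(T):
    """All firing rates an output neuron can report: spike count k / T, k = 0..T."""
    return [Fraction(k, T) for k in range(T + 1)]


for T in (1, 4, 100):
    rates = achievable_rates(T)
    # number of distinct rate levels, and the gap between neighbouring levels
    print(T, len(rates), float(rates[1]))
# prints: "1 2 1.0", then "4 5 0.25", then "100 101 0.01"
```

With T = 1 an output neuron can only say 0 or 1; with T = 100 it can express rates in steps of 0.01. This is only one aspect of the effect of T.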

slucas03 commented 2 years ago

I don't know if it is because of the images, but I find it very challenging to understand the concept of a time-step when the inputs are images. I mean, in the example you have 64 images in your training window (because batch_size=64), but then you have T=100. Does that make sense? Wouldn't it be better to set T = batch_size, so that a different image of the batch is fed through the network at each clock tick?

fangwei123456 commented 2 years ago

I mean, in the example you have 64 images in your training window (because batch_size=64), but then you have T=100. Does that make sense?

Computations along the batch dimension are independent. With batch_size = 64, you can think of it as running 64 identical SNNs on 64 different inputs.

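import torch
import torch.nn as nn
from spikingjelly.clock_driven import functional, neuron

# A single-layer SNN: a linear layer followed by integrate-and-fire neurons
snn = nn.Sequential(
    nn.Linear(28, 10, bias=False),
    neuron.IFNode(v_threshold=0.01)
)

N = 64  # batch size
T = 8   # number of time-steps
M = 28  # input features
x = torch.rand([T, N, M])

with torch.no_grad():
    # y1: run the whole batch at once, one time-step per call
    y1 = 0.
    for t in range(T):
        y1 += snn(x[t])
    functional.reset_net(snn)

    # y2: run each sample on its own, resetting the SNN between samples
    y2 = []
    for i in range(N):
        y2.append(0.)
        for t in range(T):
            y2[-1] += snn(x[t, i])
        y2[-1].unsqueeze_(0)
        functional.reset_net(snn)
    y2 = torch.cat(y2)
    # the two approaches should agree, so this should print 0 (or a negligible value)
    print((y1 - y2).abs().max())
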
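The same check can be sketched without SpikingJelly. Below, a hypothetical NumPy stand-in for the network above (a bias-free linear layer followed by a simplified IF neuron with hard reset; weights and threshold chosen arbitrarily) is run on the whole batch at once and then one sample at a time. Because nothing in the update mixes different samples, the spike counts should match.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, M, K = 8, 64, 28, 10              # time-steps, batch size, inputs, outputs
W = rng.standard_normal((K, M)) * 0.1   # hypothetical weights of a bias-free linear layer
V_TH = 0.01                             # firing threshold, as in the example above


def run_snn(x_seq):
    """x_seq: [T, *batch, M] -> spike counts [*batch, K] for linear + IF neurons."""
    v = np.zeros(x_seq.shape[1:-1] + (K,))  # one membrane potential per sample and neuron
    count = np.zeros_like(v)
    for t in range(x_seq.shape[0]):
        # Multiply-and-sum rather than a matrix product, so each sample's
        # arithmetic does not depend on how many samples are in the batch.
        v += (x_seq[t][..., None, :] * W).sum(axis=-1)
        spike = (v >= V_TH).astype(float)
        v[spike == 1] = 0.0                  # hard reset of the neurons that fired
        count += spike
    return count


x = rng.random((T, N, M))
y1 = run_snn(x)                                               # whole batch at once
y2 = np.stack([run_snn(x[:, i:i + 1])[0] for i in range(N)])  # one sample at a time
print(np.abs(y1 - y2).max())
```
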
slucas03 commented 2 years ago

Sorry, I have been testing different things over the past two days. With the example you gave me, do you mean that this:

y1 = 0.
for t in range(T):
    y1 += snn(x[t])
functional.reset_net(snn)

is the same as this:

y2 = []
for i in range(N):
    y2.append(0.)
    for t in range(T):
        y2[-1] += snn(x[t, i])
    y2[-1].unsqueeze_(0)
    functional.reset_net(snn)
y2 = torch.cat(y2)
print((y1 - y2).abs().max())

So, if you want to use more than one time-step, it is very important that the input data has the correct structure, such as x = ...([T, N, M]) in your example, isn't it?

fangwei123456 commented 2 years ago

Yes. I notice that, for the moment, you are using the step-by-step mode.

In step-by-step mode, all data should have shape [N, *]; T is not one of the data's dimensions.

Refer to Propagation Pattern for more details.
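
The difference between the two modes comes down to who loops over T. The NumPy sketch below uses a hypothetical `ToyIFNeuron` class (not SpikingJelly's API): its `step` method takes one time-step of shape [N, *], as in step-by-step mode, and its `multi_step` method takes the whole [T, N, *] sequence and loops internally.

```python
import numpy as np


class ToyIFNeuron:
    """Hypothetical stateful integrate-and-fire layer (not the SpikingJelly API)."""

    def __init__(self, v_threshold=1.0):
        self.v_threshold = v_threshold
        self.v = 0.0  # membrane potential; broadcasts to the input shape on first step

    def reset(self):
        self.v = 0.0

    def step(self, x):
        """Step-by-step mode: x has shape [N, *] and holds a single time-step."""
        self.v = self.v + x                                   # charge
        spike = (self.v >= self.v_threshold).astype(x.dtype)  # fire
        self.v = self.v * (1.0 - spike)                       # hard reset where it fired
        return spike

    def multi_step(self, x_seq):
        """Multi-step mode: x_seq has shape [T, N, *]; the layer loops over T itself."""
        return np.stack([self.step(x_seq[t]) for t in range(x_seq.shape[0])])


rng = np.random.default_rng(0)
x_seq = rng.random((8, 4, 3))  # [T=8, N=4, features=3]

node = ToyIFNeuron()
# Step-by-step: the caller slices out x_seq[t] (shape [N, *]) and loops over T.
out_a = np.stack([node.step(x_seq[t]) for t in range(x_seq.shape[0])])
node.reset()  # start the second run from the same initial state
# Multi-step: one call with the full [T, N, *] tensor.
out_b = node.multi_step(x_seq)

print(out_a.shape, np.array_equal(out_a, out_b))  # (8, 4, 3) True
```

Both runs produce identical spikes; what matters is that the neuron state is reset between independent sequences.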

slucas03 commented 2 years ago

Yes, I have been doing it that way, and I had already read #103, but thank you very much for all the information you are providing :)

Now I am facing a problem: the weights of my network are not being updated. I'm going to see if I can detect the problem and fix it.

fangwei123456 commented 2 years ago

I'm going to see if I can detect the problem and fix it.

OK, you can also upload your code and we may be able to help you.

slucas03 commented 2 years ago

Well, I am just getting started with SpikingJelly, and I am using the code from the following example:

https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/clock_driven/examples/lif_fc_mnist.py

I only changed a couple of things to adapt it to my input data. However, in the example's code I added two lines to watch how the weights evolve, and I do not see any change in them. I do not know why.

for epoch in range(train_epochs): 
    net.train() 
    for param in net.parameters():
        print(param)
    for img, label in tqdm(train_data_loader):
        img = img.to(device)
        label = label.to(device)
        label_one_hot = F.one_hot(label, 10).float() 
        optimizer.zero_grad()

        for t in range(T): 
            if t == 0:
                out_spikes_counter = net(encoder(img).float())
            else:
                out_spikes_counter += net(encoder(img).float())

        out_spikes_counter_frequency = out_spikes_counter / T
        loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
        loss.backward()
        optimizer.step()

        for param2 in net.parameters():
            print(param2)

        functional.reset_net(net)

The lines that I added are:

for param in net.parameters():
    print(param)

They are OK, aren't they?

fangwei123456 commented 2 years ago

Yes, but if the change in the parameters is very small (e.g., because the learning rate is small), you cannot see the difference just by printing them. You need to do this:

import copy
param_before = copy.deepcopy(net.state_dict())

# one training iteration ....

param_after = net.state_dict()
with torch.no_grad():
    for key in param_before.keys():
        print(key, param_before[key] - param_after[key])
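
Printing a large tensor only shows its corner entries, so a real change in the middle can be invisible. Below is a framework-free NumPy sketch of summarizing the difference instead; the 10 x 784 weight and the update pattern (only inputs 100 to 683 change) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
w_before = rng.standard_normal((10, 28 * 28)).astype(np.float32)  # hypothetical 10 x 784 weight

# Hypothetical small update that only touches inputs 100..683; the first and
# last 100 inputs (roughly the top and bottom rows of a flattened 28x28 image) stay fixed.
w_after = w_before.copy()
w_after[:, 100:684] -= (1e-3 * rng.standard_normal((10, 584))).astype(np.float32)

diff = w_before - w_after
print(diff)                               # NumPy summarizes large arrays: only the first/last
                                          # 3 columns are shown here, and they are all 0
print(np.abs(diff).sum())                 # total absolute change: clearly non-zero
print(np.count_nonzero(diff), diff.size)  # how many of the 7840 entries actually moved
```

In PyTorch the analogous reductions are `(a - b).abs().sum()` and `torch.count_nonzero(a - b)`.
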
slucas03 commented 2 years ago

I added that code to the MNIST example, and this is what I got:

1.weight tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

Is that OK, or is there a problem?

fangwei123456 commented 2 years ago

It should not be all zeros. Can you print the loss to check whether it is NaN?

slucas03 commented 2 years ago

Yes, it is very strange, hence my question. Here is the value of the loss in the first iteration:

[Screenshot: loss value in the first iteration]

It is neither NaN nor zero, so I believe the weights should be changing.

fangwei123456 commented 2 years ago

OK, can you upload the complete code?

slucas03 commented 2 years ago

Yes, of course. It is the MNIST example; the only thing I have changed is adding what you told me yesterday. The train and test accuracy both increase, but I wanted to go a step further and analyze the changes in the weights, and that is when I saw that there was no change. This makes no sense to me, so I do not rule out that I am doing something wrong.

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

dataset_dir = './'
log_dir = './'
model_output_dir = './'
batch_size = 64
T = 100  # Timesteps
learning_rate = 0.001
tau = 100.0  # Membrane time constant, tau.
train_epochs = 100
writer = SummaryWriter(log_dir)

train_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False
)

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 10, bias=False),
    neuron.LIFNode(tau=tau)
)
net = net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
encoder = encoding.PoissonEncoder()

train_times = 0
max_test_accuracy = 0

test_accs = []
train_accs = []
param_before = copy.deepcopy(net.state_dict())
for epoch in range(train_epochs):
    print("Epoch {}:".format(epoch))
    print("Training...")
    train_correct_sum = 0
    train_sum = 0
    net.train() 
    for img, label in tqdm(train_data_loader):
        img = img.to(device)
        label = label.to(device)
        label_one_hot = F.one_hot(label, 10).float() 

        optimizer.zero_grad()
        for t in range(T): 
            if t == 0:
                out_spikes_counter = net(encoder(img).float())
            else:
                out_spikes_counter += net(encoder(img).float())

        out_spikes_counter_frequency = out_spikes_counter / T
        loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
        loss.backward()
        optimizer.step()

        param_after = net.state_dict()
        with torch.no_grad():
            for key in param_before.keys():
                print(key, param_before[key] - param_after[key])
        functional.reset_net(net)

My current code is exactly like this, except for the input data. Later I will change more hyperparameters, but for now I want to understand your code well in order to apply it correctly and achieve good results.

fangwei123456 commented 2 years ago

Hi, you can try this:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import numpy as np
from spikingjelly.clock_driven import neuron, encoding, functional
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import copy

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

dataset_dir = '/datasets/MNIST'
log_dir = './'
model_output_dir = './'
batch_size = 64
T = 100  # Timesteps
learning_rate = 0.001
tau = 100.0  # Membrane time constant, tau.
train_epochs = 100
writer = SummaryWriter(log_dir)

train_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=dataset_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False
)

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 10, bias=False),
    neuron.LIFNode(tau=tau)
)
net = net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
encoder = encoding.PoissonEncoder()

train_times = 0
max_test_accuracy = 0

test_accs = []
train_accs = []

for epoch in range(train_epochs):
    print("Epoch {}:".format(epoch))
    print("Training...")
    train_correct_sum = 0
    train_sum = 0
    net.train() 
    for img, label in train_data_loader:
        img = img.to(device)
        label = label.to(device)
        label_one_hot = F.one_hot(label, 10).float() 

        optimizer.zero_grad()
        param_before = copy.deepcopy(net.state_dict())
        for t in range(T): 
            if t == 0:
                out_spikes_counter = net(encoder(img).float())
            else:
                out_spikes_counter += net(encoder(img).float())

        out_spikes_counter_frequency = out_spikes_counter / T
        loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
        loss.backward()
        optimizer.step()

        param_after = net.state_dict()
        with torch.no_grad():
            for key in param_before.keys():
                print(key, (param_before[key] - param_after[key]).abs().sum())
        functional.reset_net(net)

And the outputs are:

1.weight tensor(2.4378, device='cuda:0')
1.weight tensor(2.5037, device='cuda:0')
1.weight tensor(2.5080, device='cuda:0')
1.weight tensor(2.4979, device='cuda:0')
1.weight tensor(2.4738, device='cuda:0')
1.weight tensor(2.4785, device='cuda:0')
1.weight tensor(2.5046, device='cuda:0')
1.weight tensor(2.4819, device='cuda:0')
1.weight tensor(2.4760, device='cuda:0')
1.weight tensor(2.4436, device='cuda:0')
1.weight tensor(2.4366, device='cuda:0')

If you print the difference matrix param_before[key] - param_after[key], you will find many zeros. I think only a few weights are changed.

slucas03 commented 2 years ago

Yes, it works!!! Thank you very much!!

I think only a few weights are changed.

Could it be because some weights converged from the beginning, or because some weights have no influence?

fangwei123456 commented 2 years ago

I think it is caused by zero gradients, which also means that some weights have no influence. You can try printing weight.grad.

slucas03 commented 2 years ago

I have found this:

The gradients of w1 are not all zero, there are simply a lot of zeros, especially around the border, because the MNIST images have a lot of black pixels (zeros). When multiplying with zero, the resulting gradients are also zero. By printing w1.grad you only see a very small part of the values (borders), and you just can't see the non-zero values.

It is from https://stackoverflow.com/questions/62361956/why-do-loss-backward-and-weight-grad-return-a-tensor-containing-all-zeros and it makes sense, doesn't it?

fangwei123456 commented 2 years ago

Yes. For the linear layer, dY/dW is X, and X has many zeros.
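
A minimal NumPy sketch of that argument for a bias-free linear layer y = W x, as in the example network: by the chain rule, dL/dW = outer(dL/dy, x), so every weight column that multiplies a zero pixel gets exactly zero gradient. The 28 x 28 image with a black 4-pixel border and the upstream gradient are hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 28x28 "image" with a black 4-pixel border, similar to MNIST digits.
img = np.zeros((28, 28))
img[4:24, 4:24] = rng.random((20, 20))
x = img.reshape(-1)                        # flattened input, shape [784]

W = rng.standard_normal((10, 784)) * 0.01  # bias-free linear layer: y = W @ x
y = W @ x                                  # forward pass, shape [10]

# Stand-in for the upstream gradient dL/dy coming from the loss.
dL_dy = rng.standard_normal(10)

# Since y_i = sum_j W_ij * x_j, dy_i/dW_ij = x_j, hence dL/dW = outer(dL/dy, x).
grad_W = np.outer(dL_dy, x)                # shape [10, 784], same as W

zero_inputs = (x == 0)
print(np.all(grad_W[:, zero_inputs] == 0))  # True: columns for black pixels are exactly zero
print(zero_inputs.mean())                   # fraction of columns that get no update from this input
```

With the Poisson encoder, a pixel of intensity 0 practically never spikes, so the same columns also receive zero gradient when summed over all T time-steps and over the whole batch.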

slucas03 commented 2 years ago

Thank you for all your support!!