fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Question about training with STDP #333

Open rejoie opened 1 year ago

rejoie commented 1 year ago

SpikingJelly version

0.0.0.0.13

Description

When I followed conv_fashion_mnist.py and applied STDP to the MNIST dataset, I found that the outputs are all 0 and the model gradients are all 0, so the network does not train. The same thing happens after I comment out the weight initialization `nn.init.constant(net[0].weight.data, 0.4)` in stdp_trace.py, as shown below. [image]

I think this is a weight-initialization problem. Following the tutorial: when the initial weights are too small, the postsynaptic neurons never fire, i.e. $s[j][t]$ is 0, so tr_post is also 0, and the weight update therefore ends up being 0 as well.
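For reference, the trace-based update rule from the STDP tutorial works out as follows (the notation is reconstructed from the tutorial, so treat the exact symbols as an assumption):

$$tr_{pre}[i][t] = tr_{pre}[i][t-1] - \frac{tr_{pre}[i][t-1]}{\tau_{pre}} + s[i][t]$$

$$tr_{post}[j][t] = tr_{post}[j][t-1] - \frac{tr_{post}[j][t-1]}{\tau_{post}} + s[j][t]$$

$$\Delta w[i][j][t] = F_{post}(w[i][j][t]) \cdot tr_{pre}[i][t] \cdot s[j][t] - F_{pre}(w[i][j][t]) \cdot tr_{post}[j][t] \cdot s[i][t]$$

If the postsynaptic neuron never fires, $s[j][t] = 0$ removes the first (potentiation) term and $tr_{post}[j][t] = 0$ removes the second (depression) term, so $\Delta w$ stays identically 0, which matches the behavior described above.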

Part of the code is shown below. The parameters are: T=4, b=128, channels=32, j=4, lr=0.01, momentum=0.9, tau=2.0, tau_post=2.0, tau_pre=2.0, w_max=1.0, w_min=-1.0.

How should I go about training with STDP on MNIST?


# imports inferred from the snippets below (assumed to match the original script):
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import functional, layer, learning, neuron


class CSNN(nn.Module):
    def __init__(self, T: int, tau: float, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            neuron.LIFNode(tau=tau),
            layer.MaxPool2d(2, 2),  # 14 * 14

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            neuron.LIFNode(tau=tau),
            layer.MaxPool2d(2, 2),  # 7 * 7

            layer.Flatten(),
            layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
            neuron.LIFNode(tau=tau),

            layer.Linear(channels * 4 * 4, 10, bias=False),
            neuron.LIFNode(tau=tau),
        )

        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)  # mean firing rate over the T time steps
        return fr

# attach an STDP learner to every (synapse, spiking neuron) pair;
# in this Sequential, each LIFNode directly follows its synapse, so
# net.conv_fc[i - 1] is the Conv2d/Linear feeding net.conv_fc[i]
learners = []
for i in range(len(net.conv_fc)):
    if isinstance(net.conv_fc[i], neuron.BaseNode):
        learner = learning.STDPLearner(step_mode='m', synapse=net.conv_fc[i - 1], sn=net.conv_fc[i],
                                       tau_pre=args.tau_pre, tau_post=args.tau_post,
                                       f_pre=f_pre, f_post=f_post)
        learners.append(learner)

for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    with torch.no_grad():  # no backprop: STDP supplies the weight updates
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            out_fr = net(img)
            loss = F.mse_loss(out_fr, label_onehot)  # logged only, not backpropagated
            # loss.backward()
            for learner in learners:
                learner.step(on_grad=True)  # write the STDP update into .grad
                learner.reset()
            optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        print(f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
        print(f'train speed ={train_speed: .4f} images/s')
        print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')
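
One mitigation consistent with the diagnosis above would be to give every synaptic layer the same constant positive initialization that stdp_trace.py uses, so the postsynaptic neurons fire from the very first step. A minimal sketch against the CSNN above; whether 0.4 is still a good constant for this deeper network is an open assumption:

# give every plastic layer a positive constant initial weight, mirroring
# nn.init.constant_(net[0].weight.data, 0.4) from stdp_trace.py;
# the value 0.4 is carried over from that single-layer example (assumption)
for m in net.conv_fc:
    if isinstance(m, (layer.Conv2d, layer.Linear)):
        nn.init.constant_(m.weight.data, 0.4)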
fangwei123456 commented 1 year ago

Training a deep SNN directly with STDP is very difficult. Judging from the existing literature, there are usually two approaches: 1. use STDP to train an SNN with only a single layer; 2. train with deep-learning methods first, then fine-tune with STDP.
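For the second approach, a hypothetical two-phase schedule could look like the sketch below (not something verified in this thread; `bp_epochs`, the epoch at which training switches from backprop to STDP, is an invented parameter):

# phase 1 (epoch < bp_epochs): surrogate-gradient training via backprop;
# phase 2 (epoch >= bp_epochs): STDP fine-tuning through .grad
for epoch in range(args.epochs):
    stdp_phase = epoch >= bp_epochs  # hypothetical split point
    for img, label in train_data_loader:
        optimizer.zero_grad()
        img, label = img.to(args.device), label.to(args.device)
        out_fr = net(img)
        if stdp_phase:
            for learner in learners:
                learner.step(on_grad=True)  # STDP update written into .grad
        else:
            loss = F.mse_loss(out_fr, F.one_hot(label, 10).float())
            loss.backward()
        optimizer.step()
        for learner in learners:
            learner.reset()  # clear recorded spikes in both phases
        functional.reset_net(net)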

rejoie commented 1 year ago

> Training a deep SNN directly with STDP is very difficult. Judging from the existing literature, there are usually two approaches: 1. use STDP to train an SNN with only a single layer; 2. train with deep-learning methods first, then fine-tune with STDP.

I tried training only a single-layer SNN as well as a two-layer SNN, but the network still fails to train and the output is all 0. The network is as follows:

class CSNN(nn.Module):
    def __init__(self, T: int, tau: float, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Flatten(),
            # layer.Linear(28 * 28, 8 * 8, bias=False),
            # neuron.LIFNode(tau=tau),

            layer.Linear(28 * 28, 10, bias=False),
            neuron.LIFNode(tau=tau),
        )

        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr
fangwei123456 commented 1 year ago

With carefully tuned parameters, the network above might be trained successfully, but I have not tried it.
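One knob that often matters for that tuning is the weight dependence `f_pre`/`f_post` passed to STDPLearner. A common choice (an assumption on my part, not what this thread used) is a soft bound that shrinks updates as weights approach `w_min`/`w_max`, so depression cannot permanently silence a layer:

# hypothetical soft-bound weight dependence for STDPLearner's f_pre/f_post;
# w_min/w_max follow the arguments listed earlier (-1.0 and 1.0)
def f_pre(w, w_min=-1.0):
    return torch.clamp(w - w_min, min=0.0)   # depression fades as w -> w_min

def f_post(w, w_max=1.0):
    return torch.clamp(w_max - w, min=0.0)   # potentiation fades as w -> w_max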

fight-think commented 1 month ago

Hello, I have a question: when training unsupervised with STDP only, is the correspondence between the 10 outputs of the last layer and the labels the same as in supervised training, or does that correspondence have to be determined from the training data? When I train, my classification accuracy also sits around 10% (random guessing), so I suspect the way outputs are matched to labels is the problem.

In "Unsupervised learning of digit recognition using spike-timing-dependent plasticity", the authors use the training data of each class to determine which excitatory neurons have the largest average response, and at prediction time use the responses of those neurons to decide the class. If instead we take the 10 outputs of the last layer and, for each class, pick the output with the largest average response, two different classes may end up mapped to the same output. So my question is: with purely unsupervised STDP training, how should the labels of the test data be determined?
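For reference, the assignment scheme from that paper can be sketched as follows (the shapes are assumptions: `resp` holds per-sample mean firing rates of the excitatory layer, shape [num_samples, num_neurons]):

import torch

def assign_neurons(resp: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    # for each neuron, pick the class whose training samples drive it hardest
    class_mean = torch.stack([resp[labels == c].mean(0) for c in range(num_classes)])  # [C, N]
    return class_mean.argmax(0)  # [N]: class assigned to each neuron

def predict(resp: torch.Tensor, assignment: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    # average the responses of each class's assigned neurons; highest vote wins
    votes = torch.full((resp.shape[0], num_classes), float('-inf'), device=resp.device)
    for c in range(num_classes):
        mask = assignment == c
        if mask.any():
            votes[:, c] = resp[:, mask].mean(1)
    return votes.argmax(1)  # predicted label for each sample

Under this scheme several neurons can be assigned to the same class, which avoids the collision described above where two different classes compete for a single output.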