fangwei123456 / Spike-Element-Wise-ResNet

Deep Residual Learning in Spiking Neural Networks
Mozilla Public License 2.0
136 stars 21 forks

Does the code support single-step inference? #23

Open USTCYYX opened 6 months ago

USTCYYX commented 6 months ago

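Hello. After reading the code, I found that inference runs in multi-step mode. Can I conveniently switch the network to single-step inference with from spikingjelly.clock_driven import functional followed by functional.set_step_mode(net, 's')? If not, could you suggest a workable way to run inference in single-step mode? Thanks!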

fangwei123456 commented 6 months ago

I don't think the SJ framework from this period supports switching.

USTCYYX commented 6 months ago

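I see that all the neurons in the code have the MultiStep prefix, e.g. MultiStepIFNode. Is there a corresponding single-step neuron? If I replace every multi-step neuron with a single-step one, will that give me single-step inference? In SJ, single-step versus multi-step seems to concern only the neuron layers, not the other layers.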

fangwei123456 commented 6 months ago

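Yes, you can do that: replace the neurons with single-step neurons and remove the SeqToANNContainer wrapper from the stateless layers, and the network will run in single-step mode.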
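A minimal sketch of why unwrapping the stateless layers is safe, assuming SeqToANNContainer simply fuses the T and N dimensions before calling the wrapped layers and splits them afterwards. The helper fused_tn_forward below is a hypothetical plain-PyTorch stand-in, not the SpikingJelly implementation. Applying a stateless layer to the fused [T*N, ...] tensor gives the same result as applying it to each time step separately, so in single-step mode the plain layers can take [N, C, H, W] directly.

```python
import torch
from torch import nn


def fused_tn_forward(x_seq: torch.Tensor, module: nn.Module) -> torch.Tensor:
    """Hypothetical stand-in for SeqToANNContainer: fuse [T, N, ...] -> [T*N, ...],
    run the stateless module once, then split back to [T, N, ...]."""
    T, N = x_seq.shape[0], x_seq.shape[1]
    y = module(x_seq.flatten(0, 1))
    return y.reshape(T, N, *y.shape[1:])


torch.manual_seed(0)
stateless = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
).eval()  # eval(): BatchNorm uses running stats, so each sample is processed independently

x_seq = torch.rand(4, 2, 3, 16, 16)  # [T, N, C, H, W]
with torch.no_grad():
    multi = fused_tn_forward(x_seq, stateless)                     # T and N fused
    single = torch.stack([stateless(x_seq[t]) for t in range(4)])  # one step at a time

print(torch.allclose(multi, single, atol=1e-5))
```

Note the eval() call: in training mode BatchNorm computes statistics over the whole batch, which would be T*N samples when fused but only N samples per single step, so the two paths would no longer match.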

USTCYYX commented 6 months ago

I couldn't find the documentation for this version of SJ. Could you point me to the docs for the neuron classes in this version?

fangwei123456 commented 6 months ago

This version is probably 0.0.0.0.4: https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.4/

USTCYYX commented 6 months ago

For your proposed solution, here is a typical excerpt from the code:

self.conv2 = layer.SeqToANNContainer(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.MultiStepIFNode(detach_reset=True)

To switch to single-step, should I modify it like this?

self.conv2 = nn.Sequential(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.IFNode(detach_reset=True)

fangwei123456 commented 6 months ago

Yes, that works.

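One minor problem this may cause: after replacing SeqToANNContainer with Sequential, loading the trained weights may fail, because the state_dict key names change. Since SeqToANNContainer fuses the T and N dimensions, you could instead keep it and, in single-step mode, register hooks on SeqToANNContainer: in the forward pass, unsqueeze(0) the input to add a T=1 dimension, then squeeze(0) the output to remove it again.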
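A minimal sketch of the hook idea, assuming a SeqToANNContainer-like wrapper that stores its layers under .module and expects [T, N, ...] input. SeqToANNLike and make_single_step below are hypothetical plain-PyTorch stand-ins, not SpikingJelly code. The forward pre-hook adds a T=1 dimension and the forward hook removes it, so the wrapped block accepts and returns single-step [N, C, H, W] tensors while its structure and state_dict keys stay unchanged.

```python
import torch
from torch import nn


class SeqToANNLike(nn.Module):
    """Hypothetical stand-in for SeqToANNContainer: layers live under `.module`,
    and the forward pass fuses [T, N, ...] into [T*N, ...]."""

    def __init__(self, *layers: nn.Module):
        super().__init__()
        self.module = nn.Sequential(*layers)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        T, N = x_seq.shape[0], x_seq.shape[1]
        y = self.module(x_seq.flatten(0, 1))
        return y.reshape(T, N, *y.shape[1:])


def add_t_dim(module, inputs):
    # forward pre-hook: [N, C, H, W] -> [1, N, C, H, W]
    return (inputs[0].unsqueeze(0),)


def drop_t_dim(module, inputs, output):
    # forward hook: [1, N, C, H, W] -> [N, C, H, W]
    return output.squeeze(0)


def make_single_step(root: nn.Module) -> list:
    """Register both hooks on every SeqToANNLike inside `root`; return the handles."""
    handles = []
    for m in root.modules():
        if isinstance(m, SeqToANNLike):
            handles.append(m.register_forward_pre_hook(add_t_dim))
            handles.append(m.register_forward_hook(drop_t_dim))
    return handles


conv2 = SeqToANNLike(nn.Conv2d(4, 4, 3, padding=1, bias=False), nn.BatchNorm2d(4)).eval()
keys_before = list(conv2.state_dict().keys())
handles = make_single_step(conv2)
with torch.no_grad():
    y = conv2(torch.rand(2, 4, 8, 8))  # single-step input [N, C, H, W]
print(y.shape)  # torch.Size([2, 4, 8, 8])
for h in handles:
    h.remove()  # restore the original multi-step behaviour
```

Because hooks do not touch parameters or buffers, a checkpoint trained in multi-step mode still loads into the hooked model without renaming any keys.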

USTCYYX commented 6 months ago

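"Since SeqToANNContainer fuses T and N, you could, in single-step mode, register hooks on SeqToANNContainer: in the forward pass, unsqueeze(0) the input to add a T=1 dimension, then squeeze(0) the output to remove it." So you mean I should not change it to nn.Sequential? Then every input would be [1,N,C,H,W], where the 1 exists only so that SeqToANNContainer can do its fusion and has no real meaning.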

USTCYYX commented 6 months ago

Normally a single-step input would be [N,C,H,W], which I would then reshape to [1,N,C,H,W] to fit SeqToANNContainer.

fangwei123456 commented 6 months ago

Yes. If you don't load weights, you don't need to worry about this.

https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.14/activation_based/container.html

USTCYYX commented 6 months ago

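OK, thank you. I do need to load weights: I train in multi-step mode and then run inference in single-step mode. I also noticed what looks like a problem. If I switch to single-step this way, the code becomes:

self.conv2 = layer.SeqToANNContainer(
    conv3x3(width, width, stride, groups, dilation),
    norm_layer(width)
)
self.sn2 = cext_neuron.IFNode(detach_reset=True)

With this approach, self.conv2 outputs [1,N,C,H,W]. Can a single-step neuron accept a 5-D tensor like [1,N,C,H,W]? In principle, a single-step neuron should receive [N,C,H,W].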

fangwei123456 commented 6 months ago

No, so you need to squeeze(0) the output of the previous layer (or the input of this layer) to remove that dimension.

USTCYYX commented 6 months ago

Could I instead leave everything unchanged, keep the multi-step neurons and SeqToANNContainer, and then do the following:

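for image in dataset:
    total_fr = 0.
    x = image.unsqueeze(0)  # add the T=1 dimension once: [N, C, H, W] -> [1, N, C, H, W]
    for t in range(T):
        output = net(x)
        total_fr += output.squeeze(0)  # drop the T=1 dimension again
    functional.reset_net(net)  # clear membrane potentials before the next image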

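The key question is whether a multi-step neuron supports being called repeatedly like this and keeps its membrane potential across those calls, since multi-step neurons normally don't face repeated inputs of this kind.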
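One way to probe this question, sketched under a stated assumption: a multi-step IF neuron that keeps its membrane potential v as a member and clears it only in reset(). ToyMultiStepIF below is a hypothetical plain-PyTorch neuron (threshold 1, hard reset to 0), not the cext implementation. If the state persists between calls, feeding the whole sequence at once and feeding it as T separate [1, N, ...] slices give identical spikes.

```python
import torch
from torch import nn


class ToyMultiStepIF(nn.Module):
    """Hypothetical multi-step IF neuron: input [T, N, ...]; the membrane
    potential is kept in self.v across forward calls until reset()."""

    def __init__(self, v_threshold: float = 1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.v = 0.0

    def reset(self):
        self.v = 0.0

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        spikes = []
        for t in range(x_seq.shape[0]):
            self.v = self.v + x_seq[t]                         # charge
            s = (self.v >= self.v_threshold).to(x_seq.dtype)  # fire
            self.v = self.v * (1.0 - s)                        # hard reset to 0
            spikes.append(s)
        return torch.stack(spikes)


torch.manual_seed(0)
neuron = ToyMultiStepIF()
x_seq = torch.rand(8, 3, 5)  # [T, N, features]

full = neuron(x_seq)          # all T steps in one call
neuron.reset()
stepwise = torch.cat([neuron(x_seq[t].unsqueeze(0)) for t in range(8)])  # T=1 per call
neuron.reset()

print(torch.equal(full, stepwise))  # True: the state carried over between calls
```

Whether the cext neurons in the old SJ behave the same way depends on how they store v, which is exactly what the source code (or a test like the one later in this thread) has to confirm.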

fangwei123456 commented 6 months ago

Can a multi-step neuron support repeated inputs like this and remember its membrane potential across them?

The latest SJ framework can. I'm not sure about the version used in this repo; I suggest checking its source code.

USTCYYX commented 6 months ago

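Does the monitor in this version not support viewing the voltage? I'd just like to see the voltage at each time step. The source code doesn't include the cext.neuron part, and I couldn't find it in the documentation either.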

fangwei123456 commented 6 months ago

The monitor in the latest version does support viewing the voltage; you read it by accessing the neuron's member variable.
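A minimal sketch of reading the voltage through the member variable, assuming the neuron keeps its membrane potential in an attribute named v (in newer SJ this member is, to our understanding, also called v). ToyIFNode and record_v are hypothetical plain-PyTorch helpers, not SJ's monitor: a forward hook copies module.v after every single-step call.

```python
import torch
from torch import nn


class ToyIFNode(nn.Module):
    """Hypothetical single-step IF neuron: input [N, ...]; membrane potential in self.v."""

    def __init__(self, v_threshold: float = 1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.v = 0.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.v = self.v + x
        spike = (self.v >= self.v_threshold).to(x.dtype)
        self.v = self.v * (1.0 - spike)  # hard reset to 0 after a spike
        return spike


def record_v(module: nn.Module, store: list):
    """Register a forward hook that appends a copy of module.v after each call."""
    return module.register_forward_hook(
        lambda m, inp, out: store.append(m.v.detach().clone()))


neuron = ToyIFNode()
v_trace = []
handle = record_v(neuron, v_trace)
for t in range(3):
    neuron(torch.full((1,), 0.4))
handle.remove()
print([round(v.item(), 2) for v in v_trace])  # [0.4, 0.8, 0.0]
```

The hook runs after forward, so it records the voltage after any reset; at t=2 the neuron fires at 1.2 and the recorded value is 0.0. Recording the pre-reset voltage would need a separate member saved inside forward.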

USTCYYX commented 6 months ago

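Hi. From my tests, a multi-step neuron does seem to keep its voltage until the next reset. The SJ version you pointed to offers no way to inspect the membrane voltage, so I can only judge from the output spikes. Specifically, I used the following code: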

import torch
from torch import nn
from spikingjelly.cext import neuron as cext_neuron
from spikingjelly.clock_driven import layer, functional

x = torch.rand([1000, 1, 784])* 2 + 1
x = x.cuda()

net = nn.Sequential(
                layer.SeqToANNContainer(
                    nn.Linear(784, 100, bias=False)
                ),
                cext_neuron.MultiStepIFNode(alpha=2.0),
                layer.SeqToANNContainer(
                    nn.Linear(100, 10, bias=False)
                ),
                cext_neuron.MultiStepIFNode(alpha=2.0),
            )

net=net.cuda()
out=net(x)

functional.reset_net(net)

for t in range(1000):
    out1=net(x[t,:,:].unsqueeze(dim=0)).squeeze(dim=0)
    print(out1.equal(out[t]))

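Over at least 1000 time steps, the network outputs are identical, so my tentative conclusion is that the state is retained. However, when I tested sew_resnet, single-step and multi-step outputs seemed to differ. Here is my test code: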

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')

    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.to('cuda:0')
    out=model(x)
    print(out)

#    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
#    model1.to('cuda:0')
#    out_all=0.
#    for t in range(3):
#        out1=model1(x)/3.
#        out_all+=out1
#    print(out_all)

if __name__ == "__main__":
    main()

Here I used two networks, model and model1, because T must be set in advance: model replicates the input according to T.

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x.unsqueeze_(0)
        x = x.repeat(self.T, 1, 1, 1, 1)
        x = self.sn1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 2)
        return self.fc(x.mean(dim=0))

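This is the sew_resnet source. My understanding is that the first two layers act as an input encoding layer and that the output at the end is averaged over the time steps. So with my test method, model and model1 should produce the same output, but they don't. Could you tell me why?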

USTCYYX commented 6 months ago

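I modified the code a bit, and after loading the pretrained weights, the single-step and multi-step outputs of the code below finally agree: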

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')
    checkpoint = torch.load('sew18_checkpoint_319.pth', map_location='cuda:0')

    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')
    model.eval()
    out=model(x)
#    print(out)

    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
    model1.load_state_dict(checkpoint['model'])
    model1.to('cuda:0')
    out_all=0.
    model1.eval()
    for t in range(3):
        out1=model1(x)/3.
        out_all+=out1
#    print(out_all)

    result=torch.where(torch.abs(out-out_all)<0.01,True,False)
    print(result)
if __name__ == "__main__":
    main()

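Output: all True, i.e. the error is within an acceptable range. I believe these errors cannot be eliminated; they are inherent to single-step versus multi-step execution. However, the code below, which does not load the pretrained weights, still does not give consistent outputs: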

import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
import spiking_resnet, sew_resnet, utils

_seed_ = 2020
import random
random.seed(2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)


def main():
    x=torch.rand(1,3,224,224)
    x=x.to('cuda:0')

    model = sew_resnet.__dict__['sew_resnet18'](T=3, connect_f='ADD')
    model.to('cuda:0')
    model.eval()
    out=model(x)
    print(out)

    model1 = sew_resnet.__dict__['sew_resnet18'](T=1, connect_f='ADD')
    model1.to('cuda:0')
    model1.eval()
    out_all=0.
    for t in range(3):
        out1=model1(x)/3.
        out_all+=out1
    print(out_all)

if __name__ == "__main__":
    main()

Do you know why the code above doesn't produce consistent outputs?

fangwei123456 commented 6 months ago

https://github.com/fangwei123456/Spike-Element-Wise-ResNet/blob/cfc547c5e8196b1a16ebc596509b4dffea7c306f/imagenet/sew_resnet.py#L223

Could this be the cause? The forward pass is written for multi-step by default, and the final output is averaged directly.

USTCYYX commented 6 months ago

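I don't fully understand, but the fc layer is linear, so averaging first and then applying it should give the same result as applying it first and then averaging, right? That holds even with a bias. Also, how do I use the dvsgesture weights? There seems to be a key-name mismatch.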
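The argument about the fc layer can be written out. For an affine layer fc(x) = Wx + b and per-step features x_t (flattened), averaging over the T time steps commutes with the layer:

```latex
\mathrm{fc}\Big(\frac{1}{T}\sum_{t=1}^{T} x_t\Big)
  = W\Big(\frac{1}{T}\sum_{t=1}^{T} x_t\Big) + b
  = \frac{1}{T}\sum_{t=1}^{T}\big(W x_t + b\big)
  = \frac{1}{T}\sum_{t=1}^{T}\mathrm{fc}(x_t)
```

The bias passes through because averaging T copies of b gives b. In floating point the two orders agree only up to rounding error.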

    checkpoint = torch.load('checkpoint_191.pth', map_location='cuda:0')

    model = smodels.__dict__['PlainNet'](connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')

Error: Missing key(s) in state_dict: "conv.0.0.module.1.weight", "conv.0.0.module.1.bias", "conv.0.0.module.1.running_mean", "conv.0.0.module.1.running_var", "conv.1.conv.0.0.module.1.weight", "conv.1.conv.0.0.module.1.bias", "conv.1.conv.0.0.module.1.running_mean", "conv.1.conv.0.0.module.1.running_var", "conv.1.conv.1.0.module.1.weight", "conv.1.conv.1.0.module.1.bias", "conv.1.conv.1.0.module.1.running_mean", "conv.1.conv.1.0.module.1.running_var", "conv.3.conv.0.0.module.1.weight", "conv.3.conv.0.0.module.1.bias", "conv.3.conv.0.0.module.1.running_mean", "conv.3.conv.0.0.module.1.running_var", "conv.3.conv.1.0.module.1.weight", "conv.3.conv.1.0.module.1.bias", "conv.3.conv.1.0.module.1.running_mean", "conv.3.conv.1.0.module.1.running_var", "conv.5.conv.0.0.module.1.weight", "conv.5.conv.0.0.module.1.bias", "conv.5.conv.0.0.module.1.running_mean", "conv.5.conv.0.0.module.1.running_var", "conv.5.conv.1.0.module.1.weight", "conv.5.conv.1.0.module.1.bias", "conv.5.conv.1.0.module.1.running_mean", "conv.5.conv.1.0.module.1.running_var", "conv.7.conv.0.0.module.1.weight", "conv.7.conv.0.0.module.1.bias", "conv.7.conv.0.0.module.1.running_mean", "conv.7.conv.0.0.module.1.running_var", "conv.7.conv.1.0.module.1.weight", "conv.7.conv.1.0.module.1.bias", "conv.7.conv.1.0.module.1.running_mean", "conv.7.conv.1.0.module.1.running_var", "conv.9.conv.0.0.module.1.weight", "conv.9.conv.0.0.module.1.bias", "conv.9.conv.0.0.module.1.running_mean", "conv.9.conv.0.0.module.1.running_var", "conv.9.conv.1.0.module.1.weight", "conv.9.conv.1.0.module.1.bias", "conv.9.conv.1.0.module.1.running_mean", "conv.9.conv.1.0.module.1.running_var", "conv.11.conv.0.0.module.1.weight", "conv.11.conv.0.0.module.1.bias", "conv.11.conv.0.0.module.1.running_mean", "conv.11.conv.0.0.module.1.running_var", "conv.11.conv.1.0.module.1.weight", "conv.11.conv.1.0.module.1.bias", "conv.11.conv.1.0.module.1.running_mean", "conv.11.conv.1.0.module.1.running_var", "conv.13.conv.0.0.module.1.weight", 
"conv.13.conv.0.0.module.1.bias", "conv.13.conv.0.0.module.1.running_mean", "conv.13.conv.0.0.module.1.running_var", "conv.13.conv.1.0.module.1.weight", "conv.13.conv.1.0.module.1.bias", "conv.13.conv.1.0.module.1.running_mean", "conv.13.conv.1.0.module.1.running_var". Unexpected key(s) in state_dict: "conv.0.0.module.1.0.weight", "conv.0.0.module.1.0.bias", "conv.0.0.module.1.0.running_mean", "conv.0.0.module.1.0.running_var", "conv.0.0.module.1.0.num_batches_tracked", "conv.1.conv.0.0.module.1.0.weight", "conv.1.conv.0.0.module.1.0.bias", "conv.1.conv.0.0.module.1.0.running_mean", "conv.1.conv.0.0.module.1.0.running_var", "conv.1.conv.0.0.module.1.0.num_batches_tracked", "conv.1.conv.1.0.module.1.0.weight", "conv.1.conv.1.0.module.1.0.bias", "conv.1.conv.1.0.module.1.0.running_mean", "conv.1.conv.1.0.module.1.0.running_var", "conv.1.conv.1.0.module.1.0.num_batches_tracked", "conv.3.conv.0.0.module.1.0.weight", "conv.3.conv.0.0.module.1.0.bias", "conv.3.conv.0.0.module.1.0.running_mean", "conv.3.conv.0.0.module.1.0.running_var", "conv.3.conv.0.0.module.1.0.num_batches_tracked", "conv.3.conv.1.0.module.1.0.weight", "conv.3.conv.1.0.module.1.0.bias", "conv.3.conv.1.0.module.1.0.running_mean", "conv.3.conv.1.0.module.1.0.running_var", "conv.3.conv.1.0.module.1.0.num_batches_tracked", "conv.5.conv.0.0.module.1.0.weight", "conv.5.conv.0.0.module.1.0.bias", "conv.5.conv.0.0.module.1.0.running_mean", "conv.5.conv.0.0.module.1.0.running_var", "conv.5.conv.0.0.module.1.0.num_batches_tracked", "conv.5.conv.1.0.module.1.0.weight", "conv.5.conv.1.0.module.1.0.bias", "conv.5.conv.1.0.module.1.0.running_mean", "conv.5.conv.1.0.module.1.0.running_var", "conv.5.conv.1.0.module.1.0.num_batches_tracked", "conv.7.conv.0.0.module.1.0.weight", "conv.7.conv.0.0.module.1.0.bias", "conv.7.conv.0.0.module.1.0.running_mean", "conv.7.conv.0.0.module.1.0.running_var", "conv.7.conv.0.0.module.1.0.num_batches_tracked", "conv.7.conv.1.0.module.1.0.weight", "conv.7.conv.1.0.module.1.0.bias", 
"conv.7.conv.1.0.module.1.0.running_mean", "conv.7.conv.1.0.module.1.0.running_var", "conv.7.conv.1.0.module.1.0.num_batches_tracked", "conv.9.conv.0.0.module.1.0.weight", "conv.9.conv.0.0.module.1.0.bias", "conv.9.conv.0.0.module.1.0.running_mean", "conv.9.conv.0.0.module.1.0.running_var", "conv.9.conv.0.0.module.1.0.num_batches_tracked", "conv.9.conv.1.0.module.1.0.weight", "conv.9.conv.1.0.module.1.0.bias", "conv.9.conv.1.0.module.1.0.running_mean", "conv.9.conv.1.0.module.1.0.running_var", "conv.9.conv.1.0.module.1.0.num_batches_tracked", "conv.11.conv.0.0.module.1.0.weight", "conv.11.conv.0.0.module.1.0.bias", "conv.11.conv.0.0.module.1.0.running_mean", "conv.11.conv.0.0.module.1.0.running_var", "conv.11.conv.0.0.module.1.0.num_batches_tracked", "conv.11.conv.1.0.module.1.0.weight", "conv.11.conv.1.0.module.1.0.bias", "conv.11.conv.1.0.module.1.0.running_mean", "conv.11.conv.1.0.module.1.0.running_var", "conv.11.conv.1.0.module.1.0.num_batches_tracked", "conv.13.conv.0.0.module.1.0.weight", "conv.13.conv.0.0.module.1.0.bias", "conv.13.conv.0.0.module.1.0.running_mean", "conv.13.conv.0.0.module.1.0.running_var", "conv.13.conv.0.0.module.1.0.num_batches_tracked", "conv.13.conv.1.0.module.1.0.weight", "conv.13.conv.1.0.module.1.0.bias", "conv.13.conv.1.0.module.1.0.running_mean", "conv.13.conv.1.0.module.1.0.running_var", "conv.13.conv.1.0.module.1.0.num_batches_tracked".

fangwei123456 commented 6 months ago

It's probably not the fc mean issue.

But the code below, which doesn't load the pretrained weights, still doesn't give consistent outputs:

The weight initialization is different.
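A minimal sketch of what this implies: two separately constructed models draw different random initial weights, so their outputs cannot be compared until one model's state_dict is copied into the other. The small build() network is a hypothetical stand-in for sew_resnet.__dict__['sew_resnet18'](T=..., connect_f='ADD'); since T is not a parameter, both variants should share the same keys, as the successful load of one checkpoint into both models earlier in this thread suggests.

```python
import torch
from torch import nn


def build():
    # hypothetical stand-in for sew_resnet.__dict__['sew_resnet18'](T=..., connect_f='ADD')
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))


model, model1 = build(), build()
x = torch.rand(2, 16)
with torch.no_grad():
    print(torch.equal(model(x), model1(x)))  # different random init: expected False

model1.load_state_dict(model.state_dict())   # share the same weights
with torch.no_grad():
    print(torch.equal(model(x), model1(x)))  # True
```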

key-name mismatch

That's caused by SeqToANNContainer; the fix was mentioned earlier: https://github.com/fangwei123456/Spike-Element-Wise-ResNet/issues/23#issuecomment-1897952381

USTCYYX commented 6 months ago

I tried it: over long runs, single-step and multi-step always drift apart by small errors. Perhaps this is unavoidable.

fangwei123456 commented 6 months ago

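The error can't be avoided. Even for a single conv layer in PyTorch, concatenating 2 batches and feeding them in together gives slightly different results from feeding them separately and concatenating the outputs.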
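A small check of this claim in plain PyTorch (layer sizes are arbitrary): the concatenated and the separate forward passes agree only within a tolerance. This is why comparisons between single-step and multi-step outputs are better done with torch.allclose and an explicit atol, or by inspecting the maximum absolute difference, rather than with exact equality.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
a = torch.rand(4, 16, 32, 32)
b = torch.rand(4, 16, 32, 32)

with torch.no_grad():
    together = conv(torch.cat([a, b]))        # one batch of 8
    separate = torch.cat([conv(a), conv(b)])  # two batches of 4, merged afterwards

max_err = (together - separate).abs().max().item()
print(f"max abs difference: {max_err:.3e}")   # may be exactly 0 on some backends
print(torch.allclose(together, separate, atol=1e-5))
```

Whether the difference is nonzero depends on the backend and on which convolution algorithm is chosen for each batch size, so the printed value will vary between machines.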

USTCYYX commented 6 months ago

About the dvsgesture model issue mentioned earlier: I simplified the weight-loading procedure from your code. Did I get something wrong? Yours:

    model = smodels.__dict__[args.model](args.connect_f)
    print("Creating model")

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    if args.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.amp:
        scaler = amp.GradScaler()
    else:
        scaler = None

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        max_test_acc1 = checkpoint['max_test_acc1']
        test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']

Mine:

    checkpoint = torch.load('checkpoint_191.pth', map_location='cuda:0')

    model = smodels.__dict__['PlainNet'](connect_f='ADD')
    model.load_state_dict(checkpoint['model'])
    model.to('cuda:0')
fangwei123456 commented 6 months ago

Print the keys of the checkpoint's state_dict and of the model's state_dict, and compare the differences.
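A minimal sketch of that comparison, written against plain dicts so it works with any state_dict. compare_keys is a hypothetical helper; in practice you would pass checkpoint['model'] and model.state_dict(). The toy keys below are modeled on the error message earlier in the thread.

```python
def compare_keys(ckpt_sd: dict, model_sd: dict):
    """Return (missing, unexpected): keys the model expects but the checkpoint lacks,
    and keys the checkpoint has but the model does not expect."""
    ckpt_keys, model_keys = set(ckpt_sd), set(model_sd)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected


# Usage with real objects (not run here):
#   missing, unexpected = compare_keys(checkpoint['model'], model.state_dict())
ckpt = {"conv.0.0.module.0.weight": 0,
        "conv.0.0.module.1.0.weight": 0,
        "conv.0.0.module.1.0.num_batches_tracked": 0}
model_sd = {"conv.0.0.module.0.weight": 0,
            "conv.0.0.module.1.weight": 0}
missing, unexpected = compare_keys(ckpt, model_sd)
print(missing)     # ['conv.0.0.module.1.weight']
print(unexpected)  # ['conv.0.0.module.1.0.num_batches_tracked', 'conv.0.0.module.1.0.weight']
```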

USTCYYX commented 6 months ago

It's right above: the main difference is the extra .0 and the .num_batches_tracked entries at the end of the keys.

fangwei123456 commented 6 months ago

Could the .0 be caused by SeqToANNContainer? In the new framework,

https://github.com/fangwei123456/spikingjelly/blob/cb1cee00334ebeca101a155aee694252f85543a8/spikingjelly/activation_based/layer.py#L35

SeqToANNContainer inherits from Sequential, so the state dict no longer gets an extra .0 prefix; its keys match those of Sequential.
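A plain-PyTorch illustration of why inheritance changes the keys. OldStyleContainer and NewStyleContainer are hypothetical re-implementations, not SJ code: the first stores its layers under a self.module attribute (as the module segment in the keys from the error above suggests the old container does), the second subclasses nn.Sequential so the layers are its direct children.

```python
import torch
from torch import nn


class OldStyleContainer(nn.Module):
    """Hypothetical old-style wrapper: layers stored under self.module."""

    def __init__(self, *layers):
        super().__init__()
        self.module = nn.Sequential(*layers)

    def forward(self, x_seq):
        T, N = x_seq.shape[0], x_seq.shape[1]
        y = self.module(x_seq.flatten(0, 1))
        return y.reshape(T, N, *y.shape[1:])


class NewStyleContainer(nn.Sequential):
    """Hypothetical new-style wrapper: subclasses nn.Sequential directly."""

    def forward(self, x_seq):
        T, N = x_seq.shape[0], x_seq.shape[1]
        y = super().forward(x_seq.flatten(0, 1))
        return y.reshape(T, N, *y.shape[1:])


def layers():
    return [nn.Conv2d(2, 4, 3, bias=False), nn.BatchNorm2d(4)]


print(list(OldStyleContainer(*layers()).state_dict())[:2])  # ['module.0.weight', 'module.1.weight']
print(list(NewStyleContainer(*layers()).state_dict())[:2])  # ['0.weight', '1.weight']
print(list(nn.Sequential(*layers()).state_dict())[:2])      # ['0.weight', '1.weight']
```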

fangwei123456 commented 6 months ago

I suggest modifying the model definition in a similar way so that the keys match the checkpoint.
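An alternative to editing the model definition is to rewrite the checkpoint keys before loading. A minimal sketch, assuming the only differences are the ones listed in the error above: an extra .0 after the BatchNorm position module.1, plus num_batches_tracked buffers that the model does not define. remap_bn_keys and its regex are hypothetical and should be checked against the printed key diff before use.

```python
import re

# ".module.1.0.<param>" -> ".module.1.<param>", the pattern seen in the error message
_EXTRA_ZERO = re.compile(r"\.module\.1\.0\.")


def remap_bn_keys(ckpt_sd: dict) -> dict:
    """Return a new state dict with the extra '.0' removed and num_batches_tracked dropped."""
    out = {}
    for k, v in ckpt_sd.items():
        if k.endswith(".num_batches_tracked"):
            continue  # the model in this thread does not expect these buffers
        out[_EXTRA_ZERO.sub(".module.1.", k)] = v
    return out


# Usage (not run here):
#   model.load_state_dict(remap_bn_keys(checkpoint['model']))
sd = {"conv.1.conv.0.0.module.0.weight": "w_conv",
      "conv.1.conv.0.0.module.1.0.running_mean": "rm",
      "conv.1.conv.0.0.module.1.0.num_batches_tracked": "n"}
print(remap_bn_keys(sd))
```

Renaming keys only helps if the tensors under the new names have the shapes the model expects; load_state_dict will still raise an error on any shape mismatch.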

USTCYYX commented 6 months ago

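Hi, I still find the error rather large, especially after many time steps, so doing single-step inference directly with the existing code may not be realistic. But migrating to the new SJ version also looks difficult. Have you tried reproducing all of the SEW work with the new version of SJ? Would training SEW on the same datasets with the new SJ lose much accuracy?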

fangwei123456 commented 6 months ago

I haven't tried training on ImageNet with the new version, but the accuracy should be about the same.

USTCYYX commented 6 months ago

Doesn't the new SJ include the networks used specifically for DVS Gesture, i.e. 7B-Net: c32k3s1-BN-PLIF-{SEW Block-MPk2s2}*7-FC11 and Wide-7B-Net: c64k3s1-BN-PLIF-{SEW Block (c64)-MPk2s2}*4-c128k3s1-BN-PLIF-{SEW Block (c128)-MPk2s2}*3-FC10? I only see some of the larger SEW ResNets.

fangwei123456 commented 6 months ago

The SJ framework has no code specifically for the DVS gesture networks; only the ImageNet ones are implemented.

USTCYYX commented 6 months ago

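When you have time, could you add these two networks? I'm worried about making mistakes if I write them myself. Alternatively, I could post my version here once it's done and you could check it. Otherwise, if the accuracy doesn't match and I can't reproduce the results, it would be frustrating. I hope you understand.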

fangwei123456 commented 6 months ago

Shouldn't it work if you just copy the code in this folder over?

USTCYYX commented 6 months ago

Following your suggestion, would it be enough to change it like this? Original code (in cifar10_dvs):

import torch
import torch.nn as nn
from spikingjelly.cext.neuron import MultiStepParametricLIFNode
from spikingjelly.clock_driven import layer

def conv3x3(in_channels, out_channels):
    return nn.Sequential(
        layer.SeqToANNContainer(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ),
        MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

def conv1x1(in_channels, out_channels):
    return nn.Sequential(
        layer.SeqToANNContainer(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ),
        MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

class SEWBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, connect_f=None):
        super(SEWBlock, self).__init__()
        self.connect_f = connect_f
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        if self.connect_f == 'ADD':
            out += x
        elif self.connect_f == 'AND':
            out *= x
        elif self.connect_f == 'IAND':
            out = x * (1. - out)
        else:
            raise NotImplementedError(self.connect_f)

        return out

class PlainBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(PlainBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)

class BasicBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(BasicBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),

            layer.SeqToANNContainer(
                nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
                nn.BatchNorm2d(in_channels),
            ),
        )
        self.sn = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)

    def forward(self, x: torch.Tensor):
        return self.sn(x + self.conv(x))

class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels

            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.SeqToANNContainer(nn.MaxPool2d(k_pool, k_pool)))

        conv.append(nn.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, nn.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = nn.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)

Modified:

import torch
import torch.nn as nn
from spikingjelly.activation_based.neuron import ParametricLIFNode
from spikingjelly.activation_based import layer

def conv3x3(in_channels, out_channels):
    return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

def conv1x1(in_channels, out_channels):
    return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

class SEWBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, connect_f=None):
        super(SEWBlock, self).__init__()
        self.connect_f = connect_f
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        if self.connect_f == 'ADD':
            out += x
        elif self.connect_f == 'AND':
            out *= x
        elif self.connect_f == 'IAND':
            out = x * (1. - out)
        else:
            raise NotImplementedError(self.connect_f)

        return out

class PlainBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(PlainBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)

class BasicBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(BasicBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.sn = ParametricLIFNode(init_tau=2.0, detach_reset=True)

    def forward(self, x: torch.Tensor):
        return self.sn(x + self.conv(x))

class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels

            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(nn.MaxPool2d(k_pool, k_pool))

        conv.append(nn.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, nn.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = nn.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)

I changed the neurons and removed all of the SeqToANNContainer wrappers.

fangwei123456 commented 6 months ago

Yes, and then run it in single-step mode.
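Running in single-step mode means the model no longer receives the whole [T, N, ...] sequence in one forward call. Instead, the caller loops over T, feeds one [N, C, H, W] frame per call, averages the outputs, and resets the neurons afterwards. A minimal sketch under these assumptions: run_single_step is a hypothetical helper, reset_fn stands in for functional.reset_net, the input layout is [N, T, C, H, W] (as implied by the permute in the forward above), and the network's own forward would then have to skip its permute and mean over T.

```python
from typing import Callable

import torch
from torch import nn


def run_single_step(net: nn.Module, x: torch.Tensor,
                    reset_fn: Callable[[nn.Module], None]) -> torch.Tensor:
    """x: [N, T, C, H, W]. Feeds one [N, C, H, W] frame per call and
    returns the output averaged over T."""
    x_seq = x.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
    out = 0.
    for t in range(x_seq.shape[0]):
        out = out + net(x_seq[t])      # stateful neurons carry v across calls
    reset_fn(net)                      # clear membrane potentials before the next sample
    return out / x_seq.shape[0]


net = nn.Sequential(nn.Flatten(1), nn.Linear(2 * 4 * 4, 3))  # stateless toy network
x = torch.rand(5, 6, 2, 4, 4)                                 # [N=5, T=6, C=2, H=4, W=4]
y = run_single_step(net, x, reset_fn=lambda m: None)          # nothing to reset in this toy
print(y.shape)  # torch.Size([5, 3])
```

Note also that in single-step mode the flatten before the final linear layer would need to start at dimension 1 (one [N, ...] frame) rather than dimension 2, since there is no leading T dimension.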

USTCYYX commented 6 months ago

One question: in the new sewresnet, conv3x3 calls the one from layer:

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return layer.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

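which differs from the one here. Also, can the old code, modified this way, switch freely between single-step and multi-step? I found that single-step and multi-step also give different accuracy during training.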

fangwei123456 commented 6 months ago

Layers from nn can't switch between single-step and multi-step; only those in layer can.
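A rough plain-PyTorch illustration of the difference, assuming (as in newer SJ) that a switchable layer carries a step_mode attribute and, in 'm' mode, fuses T and N before calling the underlying nn layer. StepConv2d and set_step_mode here are hypothetical re-implementations, not SJ's layer.Conv2d or functional.set_step_mode. A bare nn.Conv2d has no such attribute, so there is nothing to switch, and it only accepts [N, C, H, W] (or unbatched [C, H, W]) input.

```python
import torch
from torch import nn


class StepConv2d(nn.Conv2d):
    """Hypothetical switchable conv: 's' expects [N, C, H, W], 'm' expects [T, N, C, H, W]."""

    def __init__(self, *args, step_mode: str = 's', **kwargs):
        super().__init__(*args, **kwargs)
        self.step_mode = step_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.step_mode == 's':
            return super().forward(x)
        T, N = x.shape[0], x.shape[1]
        y = super().forward(x.flatten(0, 1))  # fuse T and N
        return y.reshape(T, N, *y.shape[1:])


def set_step_mode(net: nn.Module, step_mode: str):
    """Flip every module that has a step_mode attribute; plain nn layers are skipped."""
    for m in net.modules():
        if hasattr(m, 'step_mode'):
            m.step_mode = step_mode


net = nn.Sequential(StepConv2d(2, 4, kernel_size=3, padding=1), nn.ReLU())
print(net(torch.rand(3, 2, 8, 8)).shape)     # single-step: torch.Size([3, 4, 8, 8])
set_step_mode(net, 'm')
print(net(torch.rand(5, 3, 2, 8, 8)).shape)  # multi-step: torch.Size([5, 3, 4, 8, 8])
```

Element-wise layers such as nn.ReLU work on either shape, which is why only layers that care about the batch dimension (convolution, pooling, normalization, flatten) need a step-mode-aware version.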

USTCYYX commented 6 months ago

Oh, that explains why the new SEWResNet uses layer everywhere. So I've now changed nn to layer as well, except for nn.Sequential, like this:

import torch
import torch.nn as nn
from spikingjelly.activation_based.neuron import ParametricLIFNode
from spikingjelly.activation_based import layer

def conv3x3(in_channels, out_channels):
    return nn.Sequential(
            layer.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            layer.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

def conv1x1(in_channels, out_channels):
    return nn.Sequential(
            layer.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            layer.BatchNorm2d(out_channels),
            ParametricLIFNode(init_tau=2.0, detach_reset=True)
    )

class SEWBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, connect_f=None):
        super(SEWBlock, self).__init__()
        self.connect_f = connect_f
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        if self.connect_f == 'ADD':
            out += x
        elif self.connect_f == 'AND':
            out *= x
        elif self.connect_f == 'IAND':
            out = x * (1. - out)
        else:
            raise NotImplementedError(self.connect_f)

        return out

class PlainBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(PlainBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            conv3x3(mid_channels, in_channels),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)

class BasicBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(BasicBlock, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels),
            layer.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
            layer.BatchNorm2d(in_channels),
        )
        self.sn = ParametricLIFNode(init_tau=2.0, detach_reset=True)

    def forward(self, x: torch.Tensor):
        return self.sn(x + self.conv(x))

class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels

            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.MaxPool2d(k_pool, k_pool))

        conv.append(layer.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)
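
For reference, the three `connect_f` options in `SEWBlock` combine the residual-branch output A (`out`) with the block input S (`x`) element-wise. Below is a minimal sketch on binary spike tensors in plain PyTorch. `sew_connect` is a hypothetical helper that only mirrors the branches of `SEWBlock.forward` above; it is not part of SpikingJelly.

```python
import torch

def sew_connect(a: torch.Tensor, s: torch.Tensor, connect_f: str) -> torch.Tensor:
    """Hypothetical helper mirroring SEWBlock.forward.

    a: spikes from the residual branch (out), s: spikes entering the block (x).
    """
    if connect_f == 'ADD':
        return a + s          # may produce 2, so the result is no longer binary
    elif connect_f == 'AND':
        return a * s          # 1 only where both fire
    elif connect_f == 'IAND':
        return s * (1. - a)   # (NOT a) AND s
    raise NotImplementedError(connect_f)

# All four combinations of (a, s)
a = torch.tensor([0., 0., 1., 1.])
s = torch.tensor([0., 1., 0., 1.])
for f in ('ADD', 'AND', 'IAND'):
    print(f, sew_connect(a, s, f).tolist())
```

Only ADD can output values other than 0 and 1; AND and IAND keep the block output binary.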

Can single-step and multi-step modes be switched freely now? I'm not sure whether MaxPool and Flatten will work, because the new SEWResNet doesn't use layer.Flatten but torch.flatten instead:

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.sn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        if self.avgpool.step_mode == 's':
            x = torch.flatten(x, 1)
        elif self.avgpool.step_mode == 'm':
            x = torch.flatten(x, 2)

        x = self.fc(x)

        return x
fangwei123456 commented 6 months ago

Yes, it can now.

USTCYYX commented 6 months ago

Why does the new SEWResNet use torch.flatten(x, 1) and x = torch.flatten(x, 2), and also need to check self.avgpool.step_mode? Could using layer.Flatten directly in my modified code cause unexpected errors?

USTCYYX commented 6 months ago

I mean conv.append(layer.Flatten(2)) in the modified code.

fangwei123456 commented 6 months ago

"Why does the new SEWResNet use torch.flatten(x, 1) and x = torch.flatten(x, 2)"

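It depends on whether the input is [T, N, C, H, W] or [N, C, H, W], so flatten has to start from a different dimension. layer.Flatten distinguishes between single-step and multi-step mode on its own.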
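
To illustrate this with plain PyTorch (no SpikingJelly needed): a multi-step tensor `[T, N, C, H, W]` needs `torch.flatten(x, 2)`, a single-step tensor `[N, C, H, W]` needs `torch.flatten(x, 1)`, and both give the same per-sample features. The sketch also assumes that a multi-step `layer.Flatten` merges T and N before flattening, so its `start_dim` is counted on `[T*N, C, H, W]`. `flatten_merged_tn` is a hypothetical emulation of that behavior, not the library code.

```python
import torch

T, N, C, H, W = 4, 2, 3, 5, 5
x_seq = torch.rand(T, N, C, H, W)   # multi-step input

# Multi-step: keep T and N, flatten C, H, W
multi = torch.flatten(x_seq, 2)     # [T, N, C*H*W]

# Single-step: each time step is [N, C, H, W]
single = torch.stack([torch.flatten(x_seq[t], 1) for t in range(T)])  # [T, N, C*H*W]

def flatten_merged_tn(x: torch.Tensor, start_dim: int) -> torch.Tensor:
    """Hypothetical emulation of a multi-step Flatten: merge T and N, flatten, split back."""
    t, n = x.shape[:2]
    y = torch.flatten(x.flatten(0, 1), start_dim)   # start_dim counts on [T*N, C, H, W]
    return y.reshape(t, n, *y.shape[1:])

# Under this assumption, start_dim=1 is correct in both modes
merged = flatten_merged_tn(x_seq, 1)

print(multi.shape, single.shape, merged.shape)
```

With `start_dim=2` the emulated multi-step flatten would leave C as a separate dimension (`[T, N, C, H*W]`), which is why the start dimension has to match the mode.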

USTCYYX commented 6 months ago

I tried running it on DVSCIFAR10 and got an error:

Traceback (most recent call last):
  File "train.py", line 289, in <module>
    main()
  File "train.py", line 219, in main
    out_fr = net(frame)
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sewresnet/smodels.py", line 126, in forward
    return self.out(x.mean(0))
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x1 and 128x10)

Here is part of the code:

class ResNetN(nn.Module):
    def __init__(self, layer_list, num_classes, connect_f=None):
        super(ResNetN, self).__init__()
        in_channels = 2
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']

            if 'mid_channels' in cfg_dict:
                mid_channels = cfg_dict['mid_channels']
            else:
                mid_channels = channels

            if in_channels != channels:
                if cfg_dict['up_kernel_size'] == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif cfg_dict['up_kernel_size'] == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError

            in_channels = channels

            if 'num_blocks' in cfg_dict:
                num_blocks = cfg_dict['num_blocks']
                if cfg_dict['block_type'] == 'sew':
                    for _ in range(num_blocks):
                        conv.append(SEWBlock(in_channels, mid_channels, connect_f))
                elif cfg_dict['block_type'] == 'plain':
                    for _ in range(num_blocks):
                        conv.append(PlainBlock(in_channels, mid_channels))
                elif cfg_dict['block_type'] == 'basic':
                    for _ in range(num_blocks):
                        conv.append(BasicBlock(in_channels, mid_channels))
                else:
                    raise NotImplementedError

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.MaxPool2d(k_pool, k_pool))

        conv.append(layer.Flatten(2))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

def SEWResNet(connect_f):
    layer_list = [
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 64, 'up_kernel_size': 1, 'mid_channels': 64, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
        {'channels': 128, 'up_kernel_size': 1, 'mid_channels': 128, 'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2},
    ]
    num_classes = 10
    return ResNetN(layer_list, num_classes, connect_f)
USTCYYX commented 6 months ago

The dimensions seem to be wrong.

fangwei123456 commented 6 months ago

Try changing conv.append(layer.Flatten(2)) to layer.Flatten(1). If it still fails, check the output shape of every layer.
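
A minimal reproduction of the traceback above, in plain PyTorch. It assumes the last stage outputs `[T, N, 128, 1, 1]` (seven 2x2 max-pools turn 128x128 into 1x1, so `out_features = 1 * 128 = 128`) and a batch size of 16, presumably since `2048 = 16 x 128`. `MergedTNFlatten` is a hypothetical stand-in for a multi-step `layer.Flatten`, and `trace_shapes` is a hypothetical helper that prints each layer's output shape.

```python
import torch
import torch.nn as nn

class MergedTNFlatten(nn.Module):
    """Hypothetical stand-in for a multi-step Flatten: merge T and N, flatten, split back."""
    def __init__(self, start_dim: int):
        super().__init__()
        self.start_dim = start_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t, n = x.shape[:2]
        y = torch.flatten(x.flatten(0, 1), self.start_dim)
        return y.reshape(t, n, *y.shape[1:])

def trace_shapes(seq: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
    """Run the layers one by one and print every intermediate shape."""
    for i, m in enumerate(seq):
        x = m(x)
        print(i, type(m).__name__, tuple(x.shape))
    return x

T, N = 4, 16
feat = torch.rand(T, N, 128, 1, 1)  # assumed output of the last SEW stage + max-pool
fc = nn.Linear(128, 10)             # out_features = 1 * 1 * 128

for start_dim in (2, 1):
    y = trace_shapes(nn.Sequential(MergedTNFlatten(start_dim)), feat).mean(0)
    try:
        print('start_dim', start_dim, '->', tuple(fc(y).shape))
    except RuntimeError as e:
        print('start_dim', start_dim, '-> RuntimeError:', e)
```

With `start_dim=2`, `mean(0)` yields `[16, 128, 1]`, so `nn.Linear(128, 10)` sees a last dimension of 1 and fails. With `start_dim=1` it receives `[16, 128]` and returns `[16, 10]`.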

USTCYYX commented 6 months ago

Thanks a lot! It seems to run now. But if I want to switch to single-step mode, what else needs to change in the code above?

    def forward(self, x: torch.Tensor):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.conv(x)
        return self.out(x.mean(0))

Should it be changed to:

    def forward(self, x: torch.Tensor):
        all = 0.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        for t in range(x.shape[0]):
            y = self.conv(x[t])       # single-step input: [N, 2, *, *]
            out = self.out(y)         # per-step output: [N, num_classes]
            all += out
        return all / x.shape[0]       # average over T

But I'm not sure how to change the preceding part:

        conv.append(layer.Flatten(1))

        self.conv = nn.Sequential(*conv)

        with torch.no_grad():
            x = torch.zeros([1, 1, 128, 128])
            for m in self.conv.modules():
                if isinstance(m, layer.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = layer.Linear(out_features, num_classes, bias=True)
fangwei123456 commented 6 months ago

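functional.set_step_mode applies to the whole network, so just set it to single-step. For details, please refer to the tutorial for the latest version: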

https://spikingjelly.readthedocs.io/zh-cn/latest/activation_based/basic_concept.html#id4
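
A self-contained toy sketch of why switching the step mode is enough: feeding the T steps one at a time in single-step mode reproduces the multi-step output, as long as neuron state is reset before each new sample. `ToyIFNode`, `set_step_mode_sketch` and `reset_sketch` are hypothetical stand-ins that only mimic the idea; with SpikingJelly itself one would call `functional.set_step_mode(net, 's')` and reset with `functional.reset_net(net)`. `nn.ReLU` plays the role of a stateless layer that accepts any leading dimensions.

```python
import torch
import torch.nn as nn

class ToyIFNode(nn.Module):
    """Hypothetical IF neuron: v += x; spike when v >= 1; hard reset to 0."""
    def __init__(self, step_mode: str = 'm'):
        super().__init__()
        self.step_mode = step_mode
        self.v = 0.

    def single_step(self, x: torch.Tensor) -> torch.Tensor:
        self.v = self.v + x
        spike = (self.v >= 1.).float()
        self.v = self.v * (1. - spike)  # hard reset where a spike was emitted
        return spike

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.step_mode == 's':
            return self.single_step(x)                              # x: [N, *]
        return torch.stack([self.single_step(x_t) for x_t in x])    # x: [T, N, *]

    def reset(self):
        self.v = 0.

def set_step_mode_sketch(net: nn.Module, step_mode: str):
    # Mimics the idea of functional.set_step_mode: every module with a step_mode gets the new mode
    for m in net.modules():
        if hasattr(m, 'step_mode'):
            m.step_mode = step_mode

def reset_sketch(net: nn.Module):
    # Mimics the idea of functional.reset_net: clear the state of every stateful module
    for m in net.modules():
        if hasattr(m, 'reset'):
            m.reset()

torch.manual_seed(0)
net = nn.Sequential(nn.ReLU(), ToyIFNode())
x_seq = torch.randn(5, 3, 8)  # [T, N, features]

out_multi = net(x_seq)        # multi-step: the whole sequence at once
reset_sketch(net)

set_step_mode_sketch(net, 's')
out_single = torch.stack([net(x_seq[t]) for t in range(x_seq.shape[0])])  # one step at a time
reset_sketch(net)             # reset before the next sample

print(torch.equal(out_multi, out_single))
```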

fangwei123456 commented 6 months ago

In single-step mode the last line should be return self.out(x), without mean(0).
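
Dropping `mean(0)` inside the network and averaging the per-step outputs outside it gives the same logits as the multi-step `self.out(x.mean(0))`, because `nn.Linear` is affine: the mean of its outputs over T equals its output on the mean input. A short numerical check in plain PyTorch (all shapes are arbitrary assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, D, num_classes = 4, 2, 128, 10
fc = nn.Linear(D, num_classes)
x_seq = torch.rand(T, N, D)  # flattened features for each time step

with torch.no_grad():
    multi = fc(x_seq.mean(0))                        # multi-step: average over T, then classify
    summed = sum(fc(x_seq[t]) for t in range(T))     # single-step: classify each step, accumulate
    single = summed / T                              # average the accumulated outputs

print(torch.allclose(multi, single, atol=1e-5))
print(torch.equal(single.argmax(1), summed.argmax(1)))
```

Returning the running sum instead of dividing by T only rescales the logits by T, so the predicted class is the same; dividing by T matches the multi-step output up to floating-point rounding. This relies on the last layer being a plain Linear with no spiking neuron after it.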

USTCYYX commented 6 months ago

Do you mean writing it like this:

    def forward(self, x: torch.Tensor):
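        all = 0.
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        for t in range(x.shape[0]):
            y = self.conv(x[t])       # single-step input: [N, 2, *, *]
            out = self.out(y)         # per-step output: [N, num_classes]
            all += out
        return all                    # accumulated over T, not averaged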

Here, all is the accumulated voltage.