fangwei123456 / Spike-Element-Wise-ResNet

Deep Residual Learning in Spiking Neural Networks
Mozilla Public License 2.0
140 stars 21 forks source link

请问代码是否支持单步推理? #23

Open USTCYYX opened 7 months ago

USTCYYX commented 7 months ago

您好,查看代码后我发现推理使用的是多步推理。请问我能否用 from spikingjelly.clock_driven import functional functional.set_step_mode(net, 's') 方便地将网络改为单步推理模式? 如果不能,您能否提供一个可行的办法使得可以在单步模式下推理?谢谢!

fangwei123456 commented 7 months ago

可以这样,先用functional.set_step_mode作用self,让所有子模块都在单步模式下

USTCYYX commented 7 months ago

您好,我想用spikingjelly训练cifar10和cifar100,我想用在spikingjelly/activation_based/model/spiking_vgg.py中的VGG范例模型,这些模型都是来自于ANN的VGG架构吗?然后将激活函数改成了神经元。

USTCYYX commented 7 months ago

spiking_vgg的文档有点不全,我想问下我这么用是对的吗? net = models.spiking_vgg.__dict__[args.model](neuron=neuron.IFNode(), num_classes=num_classes, surrogate_function=surrogate.ATan())

fangwei123456 commented 7 months ago

neuron=neuron.IFNode()换成neuron=neuron.IFNode

USTCYYX commented 7 months ago

我直接训练vgg16效果比较差,我想试试加上tdbn看看能不能到达那篇文章(Going Deeper With Directly-Trained Larger Spiking Neural Networks)所说的效果,net定义成这样:

spiking_vgg.__dict__[args.model](spiking_neuron=neuron.IFNode, num_classes=num_classes,surrogate_function=surrogate.ATan(),norm_layer=layer.ThresholdDependentBatchNorm2d(alpha=1., v_th=1.))

但是报错了: TypeError: __init__() missing 1 required positional argument: 'num_features' 应该是直接使用layer.ThresholdDependentBatchNorm2d报错了,但是我看spiking_vgg是提供了这个norm_layer自定义的接口的,不知道应该怎么修改。

fangwei123456 commented 7 months ago

norm_layer=layer.ThresholdDependentBatchNorm2d?

fangwei123456 commented 7 months ago

norm_layer需要是一个可调用的对象,而不是一个已经生成的模块吧

USTCYYX commented 7 months ago

norm_layer=layer.ThresholdDependentBatchNorm2d我试过了,会提示输入v_th报错:TypeError: __init__() missing 1 required positional argument: 'v_th' 另外我现在用sj自带的sew_resnet18训练cifar100数据集,并且因为cifar100的size比较小,我将第一层的77卷积核换成了33卷积核。现在的问题是过拟合很严重,在训练集上有99%的精度,但是测试集只有55%左右。我看现在ANN上resnet18能做到75.61%左右的精度。不知道您有没有什么技巧可以缓解过拟合,提高精度。

fangwei123456 commented 7 months ago

CIFAR数据集很容易过拟合,一般只能采用改进网络结构的方式。L2之类的技巧没什么效果。数据增强有一些用,参考下面的代码: https://github.com/fangwei123456/Parallel-Spiking-Neuron/blob/main/cifar10/train_cf10.py

USTCYYX commented 7 months ago

如果要改进网络的话,只能添加dropout,但是resnet原文里是不加dropout的。或者用您文章里的专门为DVScifar10准备的小网络Wide-7B-Net,不知道可不可行。

USTCYYX commented 7 months ago

不知道您有没有用sewresnet做过cifar100,可以提供一点经验吗

fangwei123456 commented 7 months ago

没有试过,但你可以用CIFAR10的网络试试

USTCYYX commented 7 months ago

您好,您说的数据增强方法确实有用,但是似乎有点过了,导致训练集都有点欠拟合。不知道数据增强中有没有参数可以调节数据增强的强度? https://github.com/fangwei123456/Parallel-Spiking-Neuron/blob/main/cifar10/train_cf10.py 中的数据增强办法。

USTCYYX commented 7 months ago

现在比较尴尬的是测试集的精度比训练集的精度更高,这可能是数据增强太强所导致的。您原来的数据增强是针对cifar10的,但是现在我用的数据集是cifar100,可能不需要这么强的数据增强。希望您能告诉我数据增强的代码中有没有参数可供调节,来降低数据增强的效果。

fangwei123456 commented 7 months ago

看一下数据增强的class的构造参数