fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
1.22k stars, 233 forks

ANN2SNN runs successfully but does not appear to be transforming the model #517

Closed 83517769 closed 2 months ago

83517769 commented 3 months ago

Read before creating a new issue

For faster response

You can @-mention the developers responsible for your issue's area. Here is the division of responsibilities:

Features | Developers
Neurons and Surrogate Functions | fangwei123456, Yanqi-Chen
CUDA Acceleration | fangwei123456, Yanqi-Chen
Reinforcement Learning | lucifer2859
ANN to SNN Conversion | DingJianhao, Lyu6PosHao
Biological Learning (e.g., STDP) | AllenYolk
Others | Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456

We are glad to add volunteers who help solve issues to the table above.

Issue type

SpikingJelly version

0.0.0.0.14

Description

I tried to convert a trained ResNet50 to an SNN with:

model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_loader)
snn_model = model_converter(model.encoder)

However, the converted model does not contain IFNode-like neurons as the tutorial's models do. I also printed the outputs of the converted model, and they do not appear to have been turned into spikes: the outputs are still decimal values. This is what I printed:

100%|██████████| 390/390 [00:33<00:00, 11.56it/s]
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (shortcut): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): Module(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Module(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
      (shortcut): Module(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      )
    )
    (1): Module(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Module(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): Module(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
      (shortcut): Module(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))
      )
    )
    (1): Module(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Module(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): Module(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
    )
    (4): Module(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
    )
    (5): Module(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      (shortcut): Module(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))
      )
    )
    (1): Module(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Module(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
)

def forward(self, x):
    conv1 = self.conv1(x); x = None
    relu = torch.nn.functional.relu(conv1, inplace = False); conv1 = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(relu)
    relu_1 = torch.nn.functional.relu(layer1_0_conv1, inplace = False); layer1_0_conv1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(relu_1); relu_1 = None
    relu_2 = torch.nn.functional.relu(layer1_0_conv2, inplace = False); layer1_0_conv2 = None
    layer1_0_conv3 = getattr(self.layer1, "0").conv3(relu_2); relu_2 = None
    layer1_0_shortcut_0 = getattr(getattr(self.layer1, "0").shortcut, "0")(relu); relu = None
    add = layer1_0_conv3 + layer1_0_shortcut_0; layer1_0_conv3 = layer1_0_shortcut_0 = None
    relu_3 = torch.nn.functional.relu(add, inplace = False); add = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(relu_3)
    relu_4 = torch.nn.functional.relu(layer1_1_conv1, inplace = False); layer1_1_conv1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(relu_4); relu_4 = None
    relu_5 = torch.nn.functional.relu(layer1_1_conv2, inplace = False); layer1_1_conv2 = None
    layer1_1_conv3 = getattr(self.layer1, "1").conv3(relu_5); relu_5 = None
    add_1 = layer1_1_conv3 + relu_3; layer1_1_conv3 = relu_3 = None
    relu_6 = torch.nn.functional.relu(add_1, inplace = False); add_1 = None
    layer1_2_conv1 = getattr(self.layer1, "2").conv1(relu_6)
    relu_7 = torch.nn.functional.relu(layer1_2_conv1, inplace = False); layer1_2_conv1 = None
    layer1_2_conv2 = getattr(self.layer1, "2").conv2(relu_7); relu_7 = None
    relu_8 = torch.nn.functional.relu(layer1_2_conv2, inplace = False); layer1_2_conv2 = None
    layer1_2_conv3 = getattr(self.layer1, "2").conv3(relu_8); relu_8 = None
    add_2 = layer1_2_conv3 + relu_6; layer1_2_conv3 = relu_6 = None
    relu_9 = torch.nn.functional.relu(add_2, inplace = False); add_2 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(relu_9)
    relu_10 = torch.nn.functional.relu(layer2_0_conv1, inplace = False); layer2_0_conv1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(relu_10); relu_10 = None
    relu_11 = torch.nn.functional.relu(layer2_0_conv2, inplace = False); layer2_0_conv2 = None
    layer2_0_conv3 = getattr(self.layer2, "0").conv3(relu_11); relu_11 = None
    layer2_0_shortcut_0 = getattr(getattr(self.layer2, "0").shortcut, "0")(relu_9); relu_9 = None
    add_3 = layer2_0_conv3 + layer2_0_shortcut_0; layer2_0_conv3 = layer2_0_shortcut_0 = None
    relu_12 = torch.nn.functional.relu(add_3, inplace = False); add_3 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(relu_12)
    relu_13 = torch.nn.functional.relu(layer2_1_conv1, inplace = False); layer2_1_conv1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(relu_13); relu_13 = None
    relu_14 = torch.nn.functional.relu(layer2_1_conv2, inplace = False); layer2_1_conv2 = None
    layer2_1_conv3 = getattr(self.layer2, "1").conv3(relu_14); relu_14 = None
    add_4 = layer2_1_conv3 + relu_12; layer2_1_conv3 = relu_12 = None
    relu_15 = torch.nn.functional.relu(add_4, inplace = False); add_4 = None
    layer2_2_conv1 = getattr(self.layer2, "2").conv1(relu_15)
    relu_16 = torch.nn.functional.relu(layer2_2_conv1, inplace = False); layer2_2_conv1 = None
    layer2_2_conv2 = getattr(self.layer2, "2").conv2(relu_16); relu_16 = None
    relu_17 = torch.nn.functional.relu(layer2_2_conv2, inplace = False); layer2_2_conv2 = None
    layer2_2_conv3 = getattr(self.layer2, "2").conv3(relu_17); relu_17 = None
    add_5 = layer2_2_conv3 + relu_15; layer2_2_conv3 = relu_15 = None
    relu_18 = torch.nn.functional.relu(add_5, inplace = False); add_5 = None
    layer2_3_conv1 = getattr(self.layer2, "3").conv1(relu_18)
    relu_19 = torch.nn.functional.relu(layer2_3_conv1, inplace = False); layer2_3_conv1 = None
    layer2_3_conv2 = getattr(self.layer2, "3").conv2(relu_19); relu_19 = None
    relu_20 = torch.nn.functional.relu(layer2_3_conv2, inplace = False); layer2_3_conv2 = None
    layer2_3_conv3 = getattr(self.layer2, "3").conv3(relu_20); relu_20 = None
    add_6 = layer2_3_conv3 + relu_18; layer2_3_conv3 = relu_18 = None
    relu_21 = torch.nn.functional.relu(add_6, inplace = False); add_6 = None
    layer3_0_conv1 = getattr(self.layer3, "0").conv1(relu_21)
    relu_22 = torch.nn.functional.relu(layer3_0_conv1, inplace = False); layer3_0_conv1 = None
    layer3_0_conv2 = getattr(self.layer3, "0").conv2(relu_22); relu_22 = None
    relu_23 = torch.nn.functional.relu(layer3_0_conv2, inplace = False); layer3_0_conv2 = None
    layer3_0_conv3 = getattr(self.layer3, "0").conv3(relu_23); relu_23 = None
    layer3_0_shortcut_0 = getattr(getattr(self.layer3, "0").shortcut, "0")(relu_21); relu_21 = None
    add_7 = layer3_0_conv3 + layer3_0_shortcut_0; layer3_0_conv3 = layer3_0_shortcut_0 = None
    relu_24 = torch.nn.functional.relu(add_7, inplace = False); add_7 = None
    layer3_1_conv1 = getattr(self.layer3, "1").conv1(relu_24)
    relu_25 = torch.nn.functional.relu(layer3_1_conv1, inplace = False); layer3_1_conv1 = None
    layer3_1_conv2 = getattr(self.layer3, "1").conv2(relu_25); relu_25 = None
    relu_26 = torch.nn.functional.relu(layer3_1_conv2, inplace = False); layer3_1_conv2 = None
    layer3_1_conv3 = getattr(self.layer3, "1").conv3(relu_26); relu_26 = None
    add_8 = layer3_1_conv3 + relu_24; layer3_1_conv3 = relu_24 = None
    relu_27 = torch.nn.functional.relu(add_8, inplace = False); add_8 = None
    layer3_2_conv1 = getattr(self.layer3, "2").conv1(relu_27)
    relu_28 = torch.nn.functional.relu(layer3_2_conv1, inplace = False); layer3_2_conv1 = None
    layer3_2_conv2 = getattr(self.layer3, "2").conv2(relu_28); relu_28 = None
    relu_29 = torch.nn.functional.relu(layer3_2_conv2, inplace = False); layer3_2_conv2 = None
    layer3_2_conv3 = getattr(self.layer3, "2").conv3(relu_29); relu_29 = None
    add_9 = layer3_2_conv3 + relu_27; layer3_2_conv3 = relu_27 = None
    relu_30 = torch.nn.functional.relu(add_9, inplace = False); add_9 = None
    layer3_3_conv1 = getattr(self.layer3, "3").conv1(relu_30)
    relu_31 = torch.nn.functional.relu(layer3_3_conv1, inplace = False); layer3_3_conv1 = None
    layer3_3_conv2 = getattr(self.layer3, "3").conv2(relu_31); relu_31 = None
    relu_32 = torch.nn.functional.relu(layer3_3_conv2, inplace = False); layer3_3_conv2 = None
    layer3_3_conv3 = getattr(self.layer3, "3").conv3(relu_32); relu_32 = None
    add_10 = layer3_3_conv3 + relu_30; layer3_3_conv3 = relu_30 = None
    relu_33 = torch.nn.functional.relu(add_10, inplace = False); add_10 = None
    layer3_4_conv1 = getattr(self.layer3, "4").conv1(relu_33)
    relu_34 = torch.nn.functional.relu(layer3_4_conv1, inplace = False); layer3_4_conv1 = None
    layer3_4_conv2 = getattr(self.layer3, "4").conv2(relu_34); relu_34 = None
    relu_35 = torch.nn.functional.relu(layer3_4_conv2, inplace = False); layer3_4_conv2 = None
    layer3_4_conv3 = getattr(self.layer3, "4").conv3(relu_35); relu_35 = None
    add_11 = layer3_4_conv3 + relu_33; layer3_4_conv3 = relu_33 = None
    relu_36 = torch.nn.functional.relu(add_11, inplace = False); add_11 = None
    layer3_5_conv1 = getattr(self.layer3, "5").conv1(relu_36)
    relu_37 = torch.nn.functional.relu(layer3_5_conv1, inplace = False); layer3_5_conv1 = None
    layer3_5_conv2 = getattr(self.layer3, "5").conv2(relu_37); relu_37 = None
    relu_38 = torch.nn.functional.relu(layer3_5_conv2, inplace = False); layer3_5_conv2 = None
    layer3_5_conv3 = getattr(self.layer3, "5").conv3(relu_38); relu_38 = None
    add_12 = layer3_5_conv3 + relu_36; layer3_5_conv3 = relu_36 = None
    relu_39 = torch.nn.functional.relu(add_12, inplace = False); add_12 = None
    layer4_0_conv1 = getattr(self.layer4, "0").conv1(relu_39)
    relu_40 = torch.nn.functional.relu(layer4_0_conv1, inplace = False); layer4_0_conv1 = None
    layer4_0_conv2 = getattr(self.layer4, "0").conv2(relu_40); relu_40 = None
    relu_41 = torch.nn.functional.relu(layer4_0_conv2, inplace = False); layer4_0_conv2 = None
    layer4_0_conv3 = getattr(self.layer4, "0").conv3(relu_41); relu_41 = None
    layer4_0_shortcut_0 = getattr(getattr(self.layer4, "0").shortcut, "0")(relu_39); relu_39 = None
    add_13 = layer4_0_conv3 + layer4_0_shortcut_0; layer4_0_conv3 = layer4_0_shortcut_0 = None
    relu_42 = torch.nn.functional.relu(add_13, inplace = False); add_13 = None
    layer4_1_conv1 = getattr(self.layer4, "1").conv1(relu_42)
    relu_43 = torch.nn.functional.relu(layer4_1_conv1, inplace = False); layer4_1_conv1 = None
    layer4_1_conv2 = getattr(self.layer4, "1").conv2(relu_43); relu_43 = None
    relu_44 = torch.nn.functional.relu(layer4_1_conv2, inplace = False); layer4_1_conv2 = None
    layer4_1_conv3 = getattr(self.layer4, "1").conv3(relu_44); relu_44 = None
    add_14 = layer4_1_conv3 + relu_42; layer4_1_conv3 = relu_42 = None
    relu_45 = torch.nn.functional.relu(add_14, inplace = False); add_14 = None
    layer4_2_conv1 = getattr(self.layer4, "2").conv1(relu_45)
    relu_46 = torch.nn.functional.relu(layer4_2_conv1, inplace = False); layer4_2_conv1 = None
    layer4_2_conv2 = getattr(self.layer4, "2").conv2(relu_46); relu_46 = None
    relu_47 = torch.nn.functional.relu(layer4_2_conv2, inplace = False); layer4_2_conv2 = None
    layer4_2_conv3 = getattr(self.layer4, "2").conv3(relu_47); relu_47 = None
    add_15 = layer4_2_conv3 + relu_45; layer4_2_conv3 = relu_45 = None
    relu_48 = torch.nn.functional.relu(add_15, inplace = False); add_15 = None
    avgpool = self.avgpool(relu_48); relu_48 = None
    flatten = torch.flatten(avgpool, 1); avgpool = None
    return flatten

opcode | name | target | args | kwargs
placeholder | x | x | () | {}
call_module | conv1 | conv1 | (x,) | {}
call_function | relu | <function relu at 0x000001368BBB10D0> | (conv1,) | {'inplace': False}
call_module | layer1_0_conv1 | layer1.0.conv1 | (relu,) | {}
call_function | relu_1 | <function relu at 0x000001368BBB10D0> | (layer1_0_conv1,) | {'inplace': False}
call_module | layer1_0_conv2 | layer1.0.conv2 | (relu_1,) | {}
call_function | relu_2 | <function relu at 0x000001368BBB10D0> | (layer1_0_conv2,) | {'inplace': False}
call_module | layer1_0_conv3 | layer1.0.conv3 | (relu_2,) | {}
call_module | layer1_0_shortcut_0 | layer1.0.shortcut.0 | (relu,) | {}
call_function | add |  | (layer1_0_conv3, layer1_0_shortcut_0) | {}
call_function | relu_3 | <function relu at 0x000001368BBB10D0> | (add,) | {'inplace': False}
call_module | layer1_1_conv1 | layer1.1.conv1 | (relu_3,) | {}
call_function | relu_4 | <function relu at 0x000001368BBB10D0> | (layer1_1_conv1,) | {'inplace': False}
call_module | layer1_1_conv2 | layer1.1.conv2 | (relu_4,) | {}
call_function | relu_5 | <function relu at 0x000001368BBB10D0> | (layer1_1_conv2,) | {'inplace': False}
call_module | layer1_1_conv3 | layer1.1.conv3 | (relu_5,) | {}
call_function | add_1 |  | (layer1_1_conv3, relu_3) | {}
call_function | relu_6 | <function relu at 0x000001368BBB10D0> | (add_1,) | {'inplace': False}
call_module | layer1_2_conv1 | layer1.2.conv1 | (relu_6,) | {}
call_function | relu_7 | <function relu at 0x000001368BBB10D0> | (layer1_2_conv1,) | {'inplace': False}
call_module | layer1_2_conv2 | layer1.2.conv2 | (relu_7,) | {}
call_function | relu_8 | <function relu at 0x000001368BBB10D0> | (layer1_2_conv2,) | {'inplace': False}
call_module | layer1_2_conv3 | layer1.2.conv3 | (relu_8,) | {}
call_function | add_2 |  | (layer1_2_conv3, relu_6) | {}
call_function | relu_9 | <function relu at 0x000001368BBB10D0> | (add_2,) | {'inplace': False}
call_module | layer2_0_conv1 | layer2.0.conv1 | (relu_9,) | {}
call_function | relu_10 | <function relu at 0x000001368BBB10D0> | (layer2_0_conv1,) | {'inplace': False}
call_module | layer2_0_conv2 | layer2.0.conv2 | (relu_10,) | {}
call_function | relu_11 | <function relu at 0x000001368BBB10D0> | (layer2_0_conv2,) | {'inplace': False}
call_module | layer2_0_conv3 | layer2.0.conv3 | (relu_11,) | {}
call_module | layer2_0_shortcut_0 | layer2.0.shortcut.0 | (relu_9,) | {}
call_function | add_3 |  | (layer2_0_conv3, layer2_0_shortcut_0) | {}
call_function | relu_12 | <function relu at 0x000001368BBB10D0> | (add_3,) | {'inplace': False}
call_module | layer2_1_conv1 | layer2.1.conv1 | (relu_12,) | {}
call_function | relu_13 | <function relu at 0x000001368BBB10D0> | (layer2_1_conv1,) | {'inplace': False}
call_module | layer2_1_conv2 | layer2.1.conv2 | (relu_13,) | {}
call_function | relu_14 | <function relu at 0x000001368BBB10D0> | (layer2_1_conv2,) | {'inplace': False}
call_module | layer2_1_conv3 | layer2.1.conv3 | (relu_14,) | {}
call_function | add_4 |  | (layer2_1_conv3, relu_12) | {}
call_function | relu_15 | <function relu at 0x000001368BBB10D0> | (add_4,) | {'inplace': False}
call_module | layer2_2_conv1 | layer2.2.conv1 | (relu_15,) | {}
call_function | relu_16 | <function relu at 0x000001368BBB10D0> | (layer2_2_conv1,) | {'inplace': False}
call_module | layer2_2_conv2 | layer2.2.conv2 | (relu_16,) | {}
call_function | relu_17 | <function relu at 0x000001368BBB10D0> | (layer2_2_conv2,) | {'inplace': False}
call_module | layer2_2_conv3 | layer2.2.conv3 | (relu_17,) | {}
call_function | add_5 |  | (layer2_2_conv3, relu_15) | {}
call_function | relu_18 | <function relu at 0x000001368BBB10D0> | (add_5,) | {'inplace': False}
call_module | layer2_3_conv1 | layer2.3.conv1 | (relu_18,) | {}
call_function | relu_19 | <function relu at 0x000001368BBB10D0> | (layer2_3_conv1,) | {'inplace': False}
call_module | layer2_3_conv2 | layer2.3.conv2 | (relu_19,) | {}
call_function | relu_20 | <function relu at 0x000001368BBB10D0> | (layer2_3_conv2,) | {'inplace': False}
call_module | layer2_3_conv3 | layer2.3.conv3 | (relu_20,) | {}
call_function | add_6 |  | (layer2_3_conv3, relu_18) | {}
call_function | relu_21 | <function relu at 0x000001368BBB10D0> | (add_6,) | {'inplace': False}
call_module | layer3_0_conv1 | layer3.0.conv1 | (relu_21,) | {}
call_function | relu_22 | <function relu at 0x000001368BBB10D0> | (layer3_0_conv1,) | {'inplace': False}
call_module | layer3_0_conv2 | layer3.0.conv2 | (relu_22,) | {}
call_function | relu_23 | <function relu at 0x000001368BBB10D0> | (layer3_0_conv2,) | {'inplace': False}
call_module | layer3_0_conv3 | layer3.0.conv3 | (relu_23,) | {}
call_module | layer3_0_shortcut_0 | layer3.0.shortcut.0 | (relu_21,) | {}
call_function | add_7 |  | (layer3_0_conv3, layer3_0_shortcut_0) | {}
call_function | relu_24 | <function relu at 0x000001368BBB10D0> | (add_7,) | {'inplace': False}
call_module | layer3_1_conv1 | layer3.1.conv1 | (relu_24,) | {}
call_function | relu_25 | <function relu at 0x000001368BBB10D0> | (layer3_1_conv1,) | {'inplace': False}
call_module | layer3_1_conv2 | layer3.1.conv2 | (relu_25,) | {}
call_function | relu_26 | <function relu at 0x000001368BBB10D0> | (layer3_1_conv2,) | {'inplace': False}
call_module | layer3_1_conv3 | layer3.1.conv3 | (relu_26,) | {}
call_function | add_8 |  | (layer3_1_conv3, relu_24) | {}
call_function | relu_27 | <function relu at 0x000001368BBB10D0> | (add_8,) | {'inplace': False}
call_module | layer3_2_conv1 | layer3.2.conv1 | (relu_27,) | {}
call_function | relu_28 | <function relu at 0x000001368BBB10D0> | (layer3_2_conv1,) | {'inplace': False}
call_module | layer3_2_conv2 | layer3.2.conv2 | (relu_28,) | {}
call_function | relu_29 | <function relu at 0x000001368BBB10D0> | (layer3_2_conv2,) | {'inplace': False}
call_module | layer3_2_conv3 | layer3.2.conv3 | (relu_29,) | {}
call_function | add_9 |  | (layer3_2_conv3, relu_27) | {}
call_function | relu_30 | <function relu at 0x000001368BBB10D0> | (add_9,) | {'inplace': False}
call_module | layer3_3_conv1 | layer3.3.conv1 | (relu_30,) | {}
call_function | relu_31 | <function relu at 0x000001368BBB10D0> | (layer3_3_conv1,) | {'inplace': False}
call_module | layer3_3_conv2 | layer3.3.conv2 | (relu_31,) | {}
call_function | relu_32 | <function relu at 0x000001368BBB10D0> | (layer3_3_conv2,) | {'inplace': False}
call_module | layer3_3_conv3 | layer3.3.conv3 | (relu_32,) | {}
call_function | add_10 |  | (layer3_3_conv3, relu_30) | {}
call_function | relu_33 | <function relu at 0x000001368BBB10D0> | (add_10,) | {'inplace': False}
call_module | layer3_4_conv1 | layer3.4.conv1 | (relu_33,) | {}
call_function | relu_34 | <function relu at 0x000001368BBB10D0> | (layer3_4_conv1,) | {'inplace': False}
call_module | layer3_4_conv2 | layer3.4.conv2 | (relu_34,) | {}
call_function | relu_35 | <function relu at 0x000001368BBB10D0> | (layer3_4_conv2,) | {'inplace': False}
call_module | layer3_4_conv3 | layer3.4.conv3 | (relu_35,) | {}
call_function | add_11 |  | (layer3_4_conv3, relu_33) | {}
call_function | relu_36 | <function relu at 0x000001368BBB10D0> | (add_11,) | {'inplace': False}
call_module | layer3_5_conv1 | layer3.5.conv1 | (relu_36,) | {}
call_function | relu_37 | <function relu at 0x000001368BBB10D0> | (layer3_5_conv1,) | {'inplace': False}
call_module | layer3_5_conv2 | layer3.5.conv2 | (relu_37,) | {}
call_function | relu_38 | <function relu at 0x000001368BBB10D0> | (layer3_5_conv2,) | {'inplace': False}
call_module | layer3_5_conv3 | layer3.5.conv3 | (relu_38,) | {}
call_function | add_12 |  | (layer3_5_conv3, relu_36) | {}
call_function | relu_39 | <function relu at 0x000001368BBB10D0> | (add_12,) | {'inplace': False}
call_module | layer4_0_conv1 | layer4.0.conv1 | (relu_39,) | {}
call_function | relu_40 | <function relu at 0x000001368BBB10D0> | (layer4_0_conv1,) | {'inplace': False}
call_module | layer4_0_conv2 | layer4.0.conv2 | (relu_40,) | {}
call_function | relu_41 | <function relu at 0x000001368BBB10D0> | (layer4_0_conv2,) | {'inplace': False}
call_module | layer4_0_conv3 | layer4.0.conv3 | (relu_41,) | {}
call_module | layer4_0_shortcut_0 | layer4.0.shortcut.0 | (relu_39,) | {}
call_function | add_13 |  | (layer4_0_conv3, layer4_0_shortcut_0) | {}
call_function | relu_42 | <function relu at 0x000001368BBB10D0> | (add_13,) | {'inplace': False}
call_module | layer4_1_conv1 | layer4.1.conv1 | (relu_42,) | {}
call_function | relu_43 | <function relu at 0x000001368BBB10D0> | (layer4_1_conv1,) | {'inplace': False}
call_module | layer4_1_conv2 | layer4.1.conv2 | (relu_43,) | {}
call_function | relu_44 | <function relu at 0x000001368BBB10D0> | (layer4_1_conv2,) | {'inplace': False}
call_module | layer4_1_conv3 | layer4.1.conv3 | (relu_44,) | {}
call_function | add_14 |  | (layer4_1_conv3, relu_42) | {}
call_function | relu_45 | <function relu at 0x000001368BBB10D0> | (add_14,) | {'inplace': False}
call_module | layer4_2_conv1 | layer4.2.conv1 | (relu_45,) | {}
call_function | relu_46 | <function relu at 0x000001368BBB10D0> | (layer4_2_conv1,) | {'inplace': False}
call_module | layer4_2_conv2 | layer4.2.conv2 | (relu_46,) | {}
call_function | relu_47 | <function relu at 0x000001368BBB10D0> | (layer4_2_conv2,) | {'inplace': False}
call_module | layer4_2_conv3 | layer4.2.conv3 | (relu_47,) | {}
call_function | add_15 |  | (layer4_2_conv3, relu_45) | {}
call_function | relu_48 | <function relu at 0x000001368BBB10D0> | (add_15,) | {'inplace': False}
call_module | avgpool | avgpool | (relu_48,) | {}
call_function | flatten | <built-in method flatten of type object at 0x00007FFF4D3E95E0> | (avgpool, 1) | {}
output | output | output | (flatten,) | {}

out_fr: tensor([[7.7281e-08, 2.3426e-01, 1.8818e-07, ..., 0.0000e+00, 5.0723e-08, 3.2106e-08],
        [7.7271e-08, 2.7506e-02, 1.8817e-07, ..., 0.0000e+00, 5.0718e-08, 3.2107e-08],
        [7.7284e-08, 1.0869e-02, 1.8817e-07, ..., 0.0000e+00, 5.0717e-08, 3.2106e-08],
        ...,
        [7.7286e-08, 0.0000e+00, 1.8817e-07, ..., 0.0000e+00, 5.0728e-08, 3.2105e-08],
        [7.7281e-08, 4.5656e-02, 1.8817e-07, ..., 0.0000e+00, 5.0728e-08, 3.2106e-08],
        [7.7276e-08, 1.5018e-01, 1.8816e-07, ..., 0.0000e+00, 5.0734e-08, 3.2107e-08]], device='cuda:0')

...
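The missing-neuron symptom can be confirmed without reading the whole printout by counting the converted network's modules by class name. The helper module_type_histogram below is a hypothetical name, not a SpikingJelly API. For the network printed above, whose repr contains only Conv2d layers, plain Module containers and AdaptiveAvgPool2d, calling it on snn_model would be expected to report 0 for "IFNode". The toy nn.Sequential only stands in for a real model.

```python
from collections import Counter

from torch import nn


def module_type_histogram(model: nn.Module) -> Counter:
    """Count every submodule of `model` (root included) by its class name."""
    return Counter(type(m).__name__ for m in model.modules())


# Toy stand-in; in the issue's script this would be module_type_histogram(snn_model).
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
hist = module_type_histogram(toy)
print(hist["Conv2d"], hist["ReLU"], hist["IFNode"])  # 2 1 0 (missing keys count as 0)
```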

Minimal code to reproduce the error/bug

import argparse
import time
import math
import sys
import torch
import torch.backends.cudnn as cudnn
from main_ce import set_loader
from torch.utils.data import DataLoader
from util import AverageMeter, accuracy
from spikingjelly.activation_based import ann2snn
from util import set_optimizer
from syops import get_model_complexity_info
from networks.resnet_big import SupConResNet, LinearClassifier
from spikingjelly.activation_based import functional
def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--ckpt', type=str,
                        default=r'save\SupCon\cifar10_models1e-5\SupCon_cifar10_resnet50_lr_0.05_decay_0.0001_bsz_32_temp_0.07_trial_0/last.pth',
                        help='path to pre-trained model')
    parser.add_argument('-T', default=5, type=int, help='simulating time-steps')

    opt = parser.parse_args()

    # set the path according to the environment
    opt.data_folder = './datasets/'

    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'. \
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    if opt.dataset == 'cifar10':
        opt.n_cls = 10
    elif opt.dataset == 'cifar100':
        opt.n_cls = 100
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt

def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = ckpt['model']

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        else:
            new_state_dict = {}
            for k, v in state_dict.items():
                k = k.replace("module.", "")
                new_state_dict[k] = v
            state_dict = new_state_dict
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

        model.load_state_dict(state_dict)
    else:
        raise NotImplementedError('This code requires GPU')

    return model, classifier, criterion

# def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
#     '''
#     :param train_ratio: split the ratio of the origin dataset as the train set
#     :type train_ratio: float
#     :param origin_dataset: the origin dataset
#     :type origin_dataset: torch.utils.data.Dataset
#     :param num_classes: total classes number, e.g., ``10`` for the MNIST dataset
#     :type num_classes: int
#     :param random_split: If ``False``, the front ratio of samples in each class will
#             be included in the train set, while the rest will be included in the test set.
#             If ``True``, this function will split samples in each class randomly. The randomness is controlled by
#             ``numpy.random.seed``
#     :type random_split: bool
#     :return: a tuple ``(train_set, test_set)``
#     :rtype: tuple
#     '''
#     label_idx = []
#     for i in range(num_classes):
#         label_idx.append([])
# 
#     for i, item in enumerate(origin_dataset):
#         y = item[1]
#         if isinstance(y, np.ndarray) or isinstance(y, torch.Tensor):
#             y = y.item()
#         label_idx[y].append(i)
#     train_idx = []
#     test_idx = []
#     if random_split:
#         for i in range(num_classes):
#             np.random.shuffle(label_idx[i])
# 
#     for i in range(num_classes):
#         pos = math.ceil(label_idx[i].__len__() * train_ratio)
#         train_idx.extend(label_idx[i][0: pos])
#         test_idx.extend(label_idx[i][pos: label_idx[i].__len__()])
# 
#     return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)

def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        #warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss: average the SNN output over T time steps
        with torch.no_grad():
            out_fr = 0.  # reset the accumulator for every batch
            for t in range(opt.T):
                out_fr += model(images)

            out_fr = out_fr / opt.T
            #features = model.encoder(images)
        functional.reset_net(model)
        output = classifier(out_fr.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg

def validate(val_loader, model, classifier, criterion, opt):
    """validation"""
    model.eval()
    classifier.eval()
    i = 0
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # Symbol_num = 0

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]
            i += len(images)
            # forward: average the SNN output over T time steps
            out_fr = 0  # reset the accumulator for every batch
            for t in range(opt.T):
                out_fr += model(images)
            out_fr = out_fr / opt.T
            functional.reset_net(model)
            output = classifier(out_fr.detach())
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    idx, len(val_loader), batch_time=batch_time,
                    loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg

def main():
    best_acc = 0
    opt = parse_option()

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, criterion = set_model(opt)
    optimizer = set_optimizer(opt, classifier)

    # training routine
    model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_loader)
    #print('Converting the model')
    snn_model = model_converter(model.encoder)
    #print('Model successfully converted to an SNN')
    print(snn_model)
    snn_model.graph.print_tabular()
    for epoch in range(1, opt.epochs + 1):
        #adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss, acc = train(train_loader, snn_model, classifier, criterion,
                          optimizer, epoch, opt)
        time2 = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, time2 - time1, acc))
        # eval for one epoch
        loss, val_acc = validate(val_loader, snn_model, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

    print('best accuracy: {:.2f}'.format(best_acc))

if __name__ == '__main__':
    main()
# ...
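The T-step simulation loop used in train and validate can be factored into a helper whose running sum is local to each call, so no state can leak from one batch into the next. This is a minimal sketch: time_averaged_output and StepCounter are hypothetical names. In the script, reset would presumably be spikingjelly.activation_based.functional.reset_net, which the script already calls after each batch; the toy StepCounter only imitates a stateful network so the averaging and the reset can be checked without SpikingJelly.

```python
from typing import Callable, Optional

import torch
from torch import nn


def time_averaged_output(
    model: nn.Module,
    x: torch.Tensor,
    T: int,
    reset: Optional[Callable[[nn.Module], None]] = None,
) -> torch.Tensor:
    """Feed the same static input to `model` for T steps and return the mean output.

    The running sum is a local variable, so nothing carries over between batches;
    `reset` (if given) clears the network state after the T steps.
    """
    if T < 1:
        raise ValueError("T must be at least 1")
    total = None
    with torch.no_grad():
        for _ in range(T):
            y = model(x)
            total = y if total is None else total + y
    if reset is not None:
        reset(model)
    return total / T


class StepCounter(nn.Module):
    """Hypothetical stateful toy model: scales its input by the number of calls so far."""

    def __init__(self):
        super().__init__()
        self.steps = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.steps += 1
        return x * self.steps


toy = StepCounter()
avg = time_averaged_output(toy, torch.ones(2, 3), T=4, reset=lambda m: setattr(m, "steps", 0))
print(avg)        # every entry is (1 + 2 + 3 + 4) / 4 = 2.5
print(toy.steps)  # 0, because reset ran after the T steps
```

In the script this would replace the inner loop with something like out_fr = time_averaged_output(model, images, opt.T, reset=functional.reset_net).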
Met4physics commented 3 months ago

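The original model passed to ann2snn should use nn.ReLU modules, not the relu function from nn.functional. Reusing a single ReLU layer in several places causes a similar problem, and there are others. At present it is hard for ann2snn to perform a fully general conversion, so I suggest either writing a conversion method for your own model or adapting your model to SpikingJelly's conversion method.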
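The reply's point can be checked with torch.fx. Assuming (not verified here against the SpikingJelly source) that ann2snn.Converter traces the network with torch.fx and only swaps call_module nodes backed by an nn.ReLU instance, activations written as F.relu leave nothing to replace, which matches the call_function relu rows in the printout. The sketch below uses three hypothetical toy blocks and a hypothetical helper relu_report to show how each activation style appears in the traced graph. Only ConvertibleBottleneck has one dedicated nn.ReLU module per activation site; SharedReLUBottleneck reuses one module twice, which would leave a single stateful neuron serving two positions after conversion.

```python
import torch
import torch.fx as fx
import torch.nn.functional as F
from torch import nn


class FunctionalBottleneck(nn.Module):
    """Activations written as F.relu calls, like the encoder in this issue."""

    def __init__(self, c: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        return F.relu(out + x)


class SharedReLUBottleneck(FunctionalBottleneck):
    """Uses an nn.ReLU module, but calls the same instance twice (layer reuse)."""

    def __init__(self, c: int = 8):
        super().__init__(c)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return self.relu(out + x)


class ConvertibleBottleneck(FunctionalBottleneck):
    """One dedicated nn.ReLU module per activation site."""

    def __init__(self, c: int = 8):
        super().__init__(c)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.conv2(out)
        return self.relu2(out + x)


def relu_report(model: nn.Module) -> dict:
    """Count how ReLU activations appear in the torch.fx graph of `model`."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    functional_calls = 0
    relu_module_calls = []  # module path of every call site backed by nn.ReLU
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in (F.relu, torch.relu):
            functional_calls += 1
        elif node.op == "call_module" and isinstance(modules[node.target], nn.ReLU):
            relu_module_calls.append(node.target)
    return {
        "functional": functional_calls,
        "module_calls": len(relu_module_calls),
        "distinct_modules": len(set(relu_module_calls)),
    }


for cls in (FunctionalBottleneck, SharedReLUBottleneck, ConvertibleBottleneck):
    print(cls.__name__, relu_report(cls()))
```

FunctionalBottleneck reports two functional calls and no ReLU modules; SharedReLUBottleneck reports two module calls but only one distinct module; ConvertibleBottleneck reports two calls to two distinct modules, which is the shape the reply recommends.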

83517769 commented 3 months ago


Thanks, I think I roughly understand now. Thank you very much.