fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
1.22k stars 233 forks source link

ANN2SNN runs successfully but does not appear to be transforming the model #517

Closed 83517769 closed 2 months ago

83517769 commented 3 months ago

Read before creating a new issue

For faster response

You can @ the corresponding developers for your issue. Here is the division:

Features Developers
Neurons and Surrogate Functions fangwei123456
CUDA Acceleration fangwei123456
Reinforcement Learning lucifer2859
ANN to SNN Conversion DingJianhao
Biological Learning (e.g., STDP) AllenYolk
Others Grasshlw

We are glad to add new developers who are volunteering to help solve issues to the above table.

Issue type

SpikingJelly version

Description 我试图将一个训练好的ResNet50进行转换成SNN,用的是:model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_loader) snn_model = model_converter(model.encoder),但是发现转出来的模型不像教程中有IFnode类似的神经元。并且我将转换出的输出打印发现貌似并没有进行脉冲化,输出依旧是小数。 这是我打印出的结果:100%|██████████| 390/390 [00:33<00:00, 11.56it/s] ResNet( (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (layer1): Module( (0): Module( (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1)) (shortcut): Module( (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1)) ) ) (1): Module( (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1)) ) (2): Module( (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1)) ) ) (layer2): Module( (0): Module( (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1)) (shortcut): Module( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2)) ) ) (1): Module( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1)) ) (2): Module( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1)) ) (3): Module( (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): 
Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1)) ) ) (layer3): Module( (0): Module( (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) (shortcut): Module( (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2)) ) ) (1): Module( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (2): Module( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (3): Module( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (4): Module( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (5): Module( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1)) ) ) (layer4): Module( (0): Module( (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1)) (shortcut): Module( (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2)) ) ) (1): Module( (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1)) ) (2): Module( 
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1)) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1)) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) )

def forward(self, x): conv1 = self.conv1(x); x = None relu = torch.nn.functional.relu(conv1, inplace = False); conv1 = None layer1_0_conv1 = getattr(self.layer1, "0").conv1(relu) relu_1 = torch.nn.functional.relu(layer1_0_conv1, inplace = False); layer1_0_conv1 = None layer1_0_conv2 = getattr(self.layer1, "0").conv2(relu_1); relu_1 = None relu_2 = torch.nn.functional.relu(layer1_0_conv2, inplace = False); layer1_0_conv2 = None layer1_0_conv3 = getattr(self.layer1, "0").conv3(relu_2); relu_2 = None layer1_0_shortcut_0 = getattr(getattr(self.layer1, "0").shortcut, "0")(relu); relu = None add = layer1_0_conv3 + layer1_0_shortcut_0; layer1_0_conv3 = layer1_0_shortcut_0 = None relu_3 = torch.nn.functional.relu(add, inplace = False); add = None layer1_1_conv1 = getattr(self.layer1, "1").conv1(relu_3) relu_4 = torch.nn.functional.relu(layer1_1_conv1, inplace = False); layer1_1_conv1 = None layer1_1_conv2 = getattr(self.layer1, "1").conv2(relu_4); relu_4 = None relu_5 = torch.nn.functional.relu(layer1_1_conv2, inplace = False); layer1_1_conv2 = None layer1_1_conv3 = getattr(self.layer1, "1").conv3(relu_5); relu_5 = None add_1 = layer1_1_conv3 + relu_3; layer1_1_conv3 = relu_3 = None relu_6 = torch.nn.functional.relu(add_1, inplace = False); add_1 = None layer1_2_conv1 = getattr(self.layer1, "2").conv1(relu_6) relu_7 = torch.nn.functional.relu(layer1_2_conv1, inplace = False); layer1_2_conv1 = None layer1_2_conv2 = getattr(self.layer1, "2").conv2(relu_7); relu_7 = None relu_8 = torch.nn.functional.relu(layer1_2_conv2, inplace = False); layer1_2_conv2 = None layer1_2_conv3 = getattr(self.layer1, "2").conv3(relu_8); relu_8 = None add_2 = layer1_2_conv3 + relu_6; layer1_2_conv3 = relu_6 = None relu_9 = torch.nn.functional.relu(add_2, inplace = False); add_2 = None layer2_0_conv1 = getattr(self.layer2, "0").conv1(relu_9) relu_10 = torch.nn.functional.relu(layer2_0_conv1, inplace = False); layer2_0_conv1 = None layer2_0_conv2 = getattr(self.layer2, "0").conv2(relu_10); relu_10 = 
None relu_11 = torch.nn.functional.relu(layer2_0_conv2, inplace = False); layer2_0_conv2 = None layer2_0_conv3 = getattr(self.layer2, "0").conv3(relu_11); relu_11 = None layer2_0_shortcut_0 = getattr(getattr(self.layer2, "0").shortcut, "0")(relu_9); relu_9 = None add_3 = layer2_0_conv3 + layer2_0_shortcut_0; layer2_0_conv3 = layer2_0_shortcut_0 = None relu_12 = torch.nn.functional.relu(add_3, inplace = False); add_3 = None layer2_1_conv1 = getattr(self.layer2, "1").conv1(relu_12) relu_13 = torch.nn.functional.relu(layer2_1_conv1, inplace = False); layer2_1_conv1 = None layer2_1_conv2 = getattr(self.layer2, "1").conv2(relu_13); relu_13 = None relu_14 = torch.nn.functional.relu(layer2_1_conv2, inplace = False); layer2_1_conv2 = None layer2_1_conv3 = getattr(self.layer2, "1").conv3(relu_14); relu_14 = None add_4 = layer2_1_conv3 + relu_12; layer2_1_conv3 = relu_12 = None relu_15 = torch.nn.functional.relu(add_4, inplace = False); add_4 = None layer2_2_conv1 = getattr(self.layer2, "2").conv1(relu_15) relu_16 = torch.nn.functional.relu(layer2_2_conv1, inplace = False); layer2_2_conv1 = None layer2_2_conv2 = getattr(self.layer2, "2").conv2(relu_16); relu_16 = None relu_17 = torch.nn.functional.relu(layer2_2_conv2, inplace = False); layer2_2_conv2 = None layer2_2_conv3 = getattr(self.layer2, "2").conv3(relu_17); relu_17 = None add_5 = layer2_2_conv3 + relu_15; layer2_2_conv3 = relu_15 = None relu_18 = torch.nn.functional.relu(add_5, inplace = False); add_5 = None layer2_3_conv1 = getattr(self.layer2, "3").conv1(relu_18) relu_19 = torch.nn.functional.relu(layer2_3_conv1, inplace = False); layer2_3_conv1 = None layer2_3_conv2 = getattr(self.layer2, "3").conv2(relu_19); relu_19 = None relu_20 = torch.nn.functional.relu(layer2_3_conv2, inplace = False); layer2_3_conv2 = None layer2_3_conv3 = getattr(self.layer2, "3").conv3(relu_20); relu_20 = None add_6 = layer2_3_conv3 + relu_18; layer2_3_conv3 = relu_18 = None relu_21 = torch.nn.functional.relu(add_6, inplace = False); 
add_6 = None layer3_0_conv1 = getattr(self.layer3, "0").conv1(relu_21) relu_22 = torch.nn.functional.relu(layer3_0_conv1, inplace = False); layer3_0_conv1 = None layer3_0_conv2 = getattr(self.layer3, "0").conv2(relu_22); relu_22 = None relu_23 = torch.nn.functional.relu(layer3_0_conv2, inplace = False); layer3_0_conv2 = None layer3_0_conv3 = getattr(self.layer3, "0").conv3(relu_23); relu_23 = None layer3_0_shortcut_0 = getattr(getattr(self.layer3, "0").shortcut, "0")(relu_21); relu_21 = None add_7 = layer3_0_conv3 + layer3_0_shortcut_0; layer3_0_conv3 = layer3_0_shortcut_0 = None relu_24 = torch.nn.functional.relu(add_7, inplace = False); add_7 = None layer3_1_conv1 = getattr(self.layer3, "1").conv1(relu_24) relu_25 = torch.nn.functional.relu(layer3_1_conv1, inplace = False); layer3_1_conv1 = None layer3_1_conv2 = getattr(self.layer3, "1").conv2(relu_25); relu_25 = None relu_26 = torch.nn.functional.relu(layer3_1_conv2, inplace = False); layer3_1_conv2 = None layer3_1_conv3 = getattr(self.layer3, "1").conv3(relu_26); relu_26 = None add_8 = layer3_1_conv3 + relu_24; layer3_1_conv3 = relu_24 = None relu_27 = torch.nn.functional.relu(add_8, inplace = False); add_8 = None layer3_2_conv1 = getattr(self.layer3, "2").conv1(relu_27) relu_28 = torch.nn.functional.relu(layer3_2_conv1, inplace = False); layer3_2_conv1 = None layer3_2_conv2 = getattr(self.layer3, "2").conv2(relu_28); relu_28 = None relu_29 = torch.nn.functional.relu(layer3_2_conv2, inplace = False); layer3_2_conv2 = None layer3_2_conv3 = getattr(self.layer3, "2").conv3(relu_29); relu_29 = None add_9 = layer3_2_conv3 + relu_27; layer3_2_conv3 = relu_27 = None relu_30 = torch.nn.functional.relu(add_9, inplace = False); add_9 = None layer3_3_conv1 = getattr(self.layer3, "3").conv1(relu_30) relu_31 = torch.nn.functional.relu(layer3_3_conv1, inplace = False); layer3_3_conv1 = None layer3_3_conv2 = getattr(self.layer3, "3").conv2(relu_31); relu_31 = None relu_32 = torch.nn.functional.relu(layer3_3_conv2, inplace = 
False); layer3_3_conv2 = None layer3_3_conv3 = getattr(self.layer3, "3").conv3(relu_32); relu_32 = None add_10 = layer3_3_conv3 + relu_30; layer3_3_conv3 = relu_30 = None relu_33 = torch.nn.functional.relu(add_10, inplace = False); add_10 = None layer3_4_conv1 = getattr(self.layer3, "4").conv1(relu_33) relu_34 = torch.nn.functional.relu(layer3_4_conv1, inplace = False); layer3_4_conv1 = None layer3_4_conv2 = getattr(self.layer3, "4").conv2(relu_34); relu_34 = None relu_35 = torch.nn.functional.relu(layer3_4_conv2, inplace = False); layer3_4_conv2 = None layer3_4_conv3 = getattr(self.layer3, "4").conv3(relu_35); relu_35 = None add_11 = layer3_4_conv3 + relu_33; layer3_4_conv3 = relu_33 = None relu_36 = torch.nn.functional.relu(add_11, inplace = False); add_11 = None layer3_5_conv1 = getattr(self.layer3, "5").conv1(relu_36) relu_37 = torch.nn.functional.relu(layer3_5_conv1, inplace = False); layer3_5_conv1 = None layer3_5_conv2 = getattr(self.layer3, "5").conv2(relu_37); relu_37 = None relu_38 = torch.nn.functional.relu(layer3_5_conv2, inplace = False); layer3_5_conv2 = None layer3_5_conv3 = getattr(self.layer3, "5").conv3(relu_38); relu_38 = None add_12 = layer3_5_conv3 + relu_36; layer3_5_conv3 = relu_36 = None relu_39 = torch.nn.functional.relu(add_12, inplace = False); add_12 = None layer4_0_conv1 = getattr(self.layer4, "0").conv1(relu_39) relu_40 = torch.nn.functional.relu(layer4_0_conv1, inplace = False); layer4_0_conv1 = None layer4_0_conv2 = getattr(self.layer4, "0").conv2(relu_40); relu_40 = None relu_41 = torch.nn.functional.relu(layer4_0_conv2, inplace = False); layer4_0_conv2 = None layer4_0_conv3 = getattr(self.layer4, "0").conv3(relu_41); relu_41 = None layer4_0_shortcut_0 = getattr(getattr(self.layer4, "0").shortcut, "0")(relu_39); relu_39 = None add_13 = layer4_0_conv3 + layer4_0_shortcut_0; layer4_0_conv3 = layer4_0_shortcut_0 = None relu_42 = torch.nn.functional.relu(add_13, inplace = False); add_13 = None layer4_1_conv1 = getattr(self.layer4, 
"1").conv1(relu_42) relu_43 = torch.nn.functional.relu(layer4_1_conv1, inplace = False); layer4_1_conv1 = None layer4_1_conv2 = getattr(self.layer4, "1").conv2(relu_43); relu_43 = None relu_44 = torch.nn.functional.relu(layer4_1_conv2, inplace = False); layer4_1_conv2 = None layer4_1_conv3 = getattr(self.layer4, "1").conv3(relu_44); relu_44 = None add_14 = layer4_1_conv3 + relu_42; layer4_1_conv3 = relu_42 = None relu_45 = torch.nn.functional.relu(add_14, inplace = False); add_14 = None layer4_2_conv1 = getattr(self.layer4, "2").conv1(relu_45) relu_46 = torch.nn.functional.relu(layer4_2_conv1, inplace = False); layer4_2_conv1 = None layer4_2_conv2 = getattr(self.layer4, "2").conv2(relu_46); relu_46 = None relu_47 = torch.nn.functional.relu(layer4_2_conv2, inplace = False); layer4_2_conv2 = None layer4_2_conv3 = getattr(self.layer4, "2").conv3(relu_47); relu_47 = None add_15 = layer4_2_conv3 + relu_45; layer4_2_conv3 = relu_45 = None relu_48 = torch.nn.functional.relu(add_15, inplace = False); add_15 = None avgpool = self.avgpool(relu_48); relu_48 = None flatten = torch.flatten(avgpool, 1); avgpool = None return flatten

opcode name target args kwargs

placeholder x x () {} call_module conv1 conv1 (x,) {} call_function relu <function relu at 0x000001368BBB10D0> (conv1,) {'inplace': False} call_module layer1_0_conv1 layer1.0.conv1 (relu,) {} call_function relu_1 <function relu at 0x000001368BBB10D0> (layer1_0_conv1,) {'inplace': False} call_module layer1_0_conv2 layer1.0.conv2 (relu_1,) {} call_function relu_2 <function relu at 0x000001368BBB10D0> (layer1_0_conv2,) {'inplace': False} call_module layer1_0_conv3 layer1.0.conv3 (relu_2,) {} call_module layer1_0_shortcut_0 layer1.0.shortcut.0 (relu,) {} call_function add (layer1_0_conv3, layer1_0_shortcut_0) {} call_function relu_3 <function relu at 0x000001368BBB10D0> (add,) {'inplace': False} call_module layer1_1_conv1 layer1.1.conv1 (relu_3,) {} call_function relu_4 <function relu at 0x000001368BBB10D0> (layer1_1_conv1,) {'inplace': False} call_module layer1_1_conv2 layer1.1.conv2 (relu_4,) {} call_function relu_5 <function relu at 0x000001368BBB10D0> (layer1_1_conv2,) {'inplace': False} call_module layer1_1_conv3 layer1.1.conv3 (relu_5,) {} call_function add_1 (layer1_1_conv3, relu_3) {} call_function relu_6 <function relu at 0x000001368BBB10D0> (add_1,) {'inplace': False} call_module layer1_2_conv1 layer1.2.conv1 (relu_6,) {} call_function relu_7 <function relu at 0x000001368BBB10D0> (layer1_2_conv1,) {'inplace': False} call_module layer1_2_conv2 layer1.2.conv2 (relu_7,) {} call_function relu_8 <function relu at 0x000001368BBB10D0> (layer1_2_conv2,) {'inplace': False} call_module layer1_2_conv3 layer1.2.conv3 (relu_8,) {} call_function add_2 (layer1_2_conv3, relu_6) {} call_function relu_9 <function relu at 0x000001368BBB10D0> (add_2,) {'inplace': False} call_module layer2_0_conv1 layer2.0.conv1 (relu_9,) {} call_function relu_10 <function relu at 0x000001368BBB10D0> (layer2_0_conv1,) {'inplace': False} call_module layer2_0_conv2 layer2.0.conv2 (relu_10,) {} call_function relu_11 <function relu at 0x000001368BBB10D0> (layer2_0_conv2,) {'inplace': False} 
call_module layer2_0_conv3 layer2.0.conv3 (relu_11,) {} call_module layer2_0_shortcut_0 layer2.0.shortcut.0 (relu_9,) {} call_function add_3 (layer2_0_conv3, layer2_0_shortcut_0) {} call_function relu_12 <function relu at 0x000001368BBB10D0> (add_3,) {'inplace': False} call_module layer2_1_conv1 layer2.1.conv1 (relu_12,) {} call_function relu_13 <function relu at 0x000001368BBB10D0> (layer2_1_conv1,) {'inplace': False} call_module layer2_1_conv2 layer2.1.conv2 (relu_13,) {} call_function relu_14 <function relu at 0x000001368BBB10D0> (layer2_1_conv2,) {'inplace': False} call_module layer2_1_conv3 layer2.1.conv3 (relu_14,) {} call_function add_4 (layer2_1_conv3, relu_12) {} call_function relu_15 <function relu at 0x000001368BBB10D0> (add_4,) {'inplace': False} call_module layer2_2_conv1 layer2.2.conv1 (relu_15,) {} call_function relu_16 <function relu at 0x000001368BBB10D0> (layer2_2_conv1,) {'inplace': False} call_module layer2_2_conv2 layer2.2.conv2 (relu_16,) {} call_function relu_17 <function relu at 0x000001368BBB10D0> (layer2_2_conv2,) {'inplace': False} call_module layer2_2_conv3 layer2.2.conv3 (relu_17,) {} call_function add_5 (layer2_2_conv3, relu_15) {} call_function relu_18 <function relu at 0x000001368BBB10D0> (add_5,) {'inplace': False} call_module layer2_3_conv1 layer2.3.conv1 (relu_18,) {} call_function relu_19 <function relu at 0x000001368BBB10D0> (layer2_3_conv1,) {'inplace': False} call_module layer2_3_conv2 layer2.3.conv2 (relu_19,) {} call_function relu_20 <function relu at 0x000001368BBB10D0> (layer2_3_conv2,) {'inplace': False} call_module layer2_3_conv3 layer2.3.conv3 (relu_20,) {} call_function add_6 (layer2_3_conv3, relu_18) {} call_function relu_21 <function relu at 0x000001368BBB10D0> (add_6,) {'inplace': False} call_module layer3_0_conv1 layer3.0.conv1 (relu_21,) {} call_function relu_22 <function relu at 0x000001368BBB10D0> (layer3_0_conv1,) {'inplace': False} call_module layer3_0_conv2 layer3.0.conv2 (relu_22,) {} call_function relu_23 
<function relu at 0x000001368BBB10D0> (layer3_0_conv2,) {'inplace': False} call_module layer3_0_conv3 layer3.0.conv3 (relu_23,) {} call_module layer3_0_shortcut_0 layer3.0.shortcut.0 (relu_21,) {} call_function add_7 (layer3_0_conv3, layer3_0_shortcut_0) {} call_function relu_24 <function relu at 0x000001368BBB10D0> (add_7,) {'inplace': False} call_module layer3_1_conv1 layer3.1.conv1 (relu_24,) {} call_function relu_25 <function relu at 0x000001368BBB10D0> (layer3_1_conv1,) {'inplace': False} call_module layer3_1_conv2 layer3.1.conv2 (relu_25,) {} call_function relu_26 <function relu at 0x000001368BBB10D0> (layer3_1_conv2,) {'inplace': False} call_module layer3_1_conv3 layer3.1.conv3 (relu_26,) {} call_function add_8 (layer3_1_conv3, relu_24) {} call_function relu_27 <function relu at 0x000001368BBB10D0> (add_8,) {'inplace': False} call_module layer3_2_conv1 layer3.2.conv1 (relu_27,) {} call_function relu_28 <function relu at 0x000001368BBB10D0> (layer3_2_conv1,) {'inplace': False} call_module layer3_2_conv2 layer3.2.conv2 (relu_28,) {} call_function relu_29 <function relu at 0x000001368BBB10D0> (layer3_2_conv2,) {'inplace': False} call_module layer3_2_conv3 layer3.2.conv3 (relu_29,) {} call_function add_9 (layer3_2_conv3, relu_27) {} call_function relu_30 <function relu at 0x000001368BBB10D0> (add_9,) {'inplace': False} call_module layer3_3_conv1 layer3.3.conv1 (relu_30,) {} call_function relu_31 <function relu at 0x000001368BBB10D0> (layer3_3_conv1,) {'inplace': False} call_module layer3_3_conv2 layer3.3.conv2 (relu_31,) {} call_function relu_32 <function relu at 0x000001368BBB10D0> (layer3_3_conv2,) {'inplace': False} call_module layer3_3_conv3 layer3.3.conv3 (relu_32,) {} call_function add_10 (layer3_3_conv3, relu_30) {} call_function relu_33 <function relu at 0x000001368BBB10D0> (add_10,) {'inplace': False} call_module layer3_4_conv1 layer3.4.conv1 (relu_33,) {} call_function relu_34 <function relu at 0x000001368BBB10D0> (layer3_4_conv1,) {'inplace': False} 
call_module layer3_4_conv2 layer3.4.conv2 (relu_34,) {} call_function relu_35 <function relu at 0x000001368BBB10D0> (layer3_4_conv2,) {'inplace': False} call_module layer3_4_conv3 layer3.4.conv3 (relu_35,) {} call_function add_11 (layer3_4_conv3, relu_33) {} call_function relu_36 <function relu at 0x000001368BBB10D0> (add_11,) {'inplace': False} call_module layer3_5_conv1 layer3.5.conv1 (relu_36,) {} call_function relu_37 <function relu at 0x000001368BBB10D0> (layer3_5_conv1,) {'inplace': False} call_module layer3_5_conv2 layer3.5.conv2 (relu_37,) {} call_function relu_38 <function relu at 0x000001368BBB10D0> (layer3_5_conv2,) {'inplace': False} call_module layer3_5_conv3 layer3.5.conv3 (relu_38,) {} call_function add_12 (layer3_5_conv3, relu_36) {} call_function relu_39 <function relu at 0x000001368BBB10D0> (add_12,) {'inplace': False} call_module layer4_0_conv1 layer4.0.conv1 (relu_39,) {} call_function relu_40 <function relu at 0x000001368BBB10D0> (layer4_0_conv1,) {'inplace': False} call_module layer4_0_conv2 layer4.0.conv2 (relu_40,) {} call_function relu_41 <function relu at 0x000001368BBB10D0> (layer4_0_conv2,) {'inplace': False} call_module layer4_0_conv3 layer4.0.conv3 (relu_41,) {} call_module layer4_0_shortcut_0 layer4.0.shortcut.0 (relu_39,) {} call_function add_13 (layer4_0_conv3, layer4_0_shortcut_0) {} call_function relu_42 <function relu at 0x000001368BBB10D0> (add_13,) {'inplace': False} call_module layer4_1_conv1 layer4.1.conv1 (relu_42,) {} call_function relu_43 <function relu at 0x000001368BBB10D0> (layer4_1_conv1,) {'inplace': False} call_module layer4_1_conv2 layer4.1.conv2 (relu_43,) {} call_function relu_44 <function relu at 0x000001368BBB10D0> (layer4_1_conv2,) {'inplace': False} call_module layer4_1_conv3 layer4.1.conv3 (relu_44,) {} call_function add_14 (layer4_1_conv3, relu_42) {} call_function relu_45 <function relu at 0x000001368BBB10D0> (add_14,) {'inplace': False} call_module layer4_2_conv1 layer4.2.conv1 (relu_45,) {} call_function 
relu_46 <function relu at 0x000001368BBB10D0> (layer4_2_conv1,) {'inplace': False} call_module layer4_2_conv2 layer4.2.conv2 (relu_46,) {} call_function relu_47 <function relu at 0x000001368BBB10D0> (layer4_2_conv2,) {'inplace': False} call_module layer4_2_conv3 layer4.2.conv3 (relu_47,) {} call_function add_15 (layer4_2_conv3, relu_45) {} call_function relu_48 <function relu at 0x000001368BBB10D0> (add_15,) {'inplace': False} call_module avgpool avgpool (relu_48,) {} call_function flatten <built-in method flatten of type object at 0x00007FFF4D3E95E0> (avgpool, 1) {} output output output (flatten,) {} out_fr: tensor([[7.7281e-08, 2.3426e-01, 1.8818e-07, ..., 0.0000e+00, 5.0723e-08, 3.2106e-08], [7.7271e-08, 2.7506e-02, 1.8817e-07, ..., 0.0000e+00, 5.0718e-08, 3.2107e-08], [7.7284e-08, 1.0869e-02, 1.8817e-07, ..., 0.0000e+00, 5.0717e-08, 3.2106e-08], ..., [7.7286e-08, 0.0000e+00, 1.8817e-07, ..., 0.0000e+00, 5.0728e-08, 3.2105e-08], [7.7281e-08, 4.5656e-02, 1.8817e-07, ..., 0.0000e+00, 5.0728e-08, 3.2106e-08], [7.7276e-08, 1.5018e-01, 1.8816e-07, ..., 0.0000e+00, 5.0734e-08, 3.2107e-08]], device='cuda:0')


Minimal code to reproduce the error/bug

import argparse
import time
import math
import torch
import torch.backends.cudnn as cudnn
from main_ce import set_loader
from torch.utils.data import DataLoader
from util import AverageMeter, accuracy
from spikingjelly.activation_based import ann2snn
from util import set_optimizer
from syops import get_model_complexity_info
from networks.resnet_big import SupConResNet, LinearClassifier
from spikingjelly.activation_based import functional
def parse_option():
    """Parse the command-line options and derive the run configuration.

    On top of the raw argparse values, the returned namespace carries a few
    derived attributes: ``data_folder``, ``lr_decay_epochs`` (turned from a
    comma-separated string into a list of ints), ``model_name``, the warm-up
    settings (only when ``--warm`` is given) and ``n_cls``.

    :return: the populated option namespace
    :rtype: argparse.Namespace
    :raises ValueError: if the dataset is neither ``cifar10`` nor ``cifar100``
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--ckpt', type=str,
                        help='path to pre-trained model')
    parser.add_argument('-T', default=5, type=int, help='simulating time-steps')

    opt = parser.parse_args()

    # set the path according to the environment
    opt.data_folder = './datasets/'

    # '60,75,90' -> [60, 75, 90]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'. \
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine schedule has reached after
            # ``warm_epochs`` epochs, not to the full learning rate.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    if opt.dataset == 'cifar10':
        opt.n_cls = 10
    elif opt.dataset == 'cifar100':
        opt.n_cls = 100
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt

def set_model(opt):
    """Build the encoder, the linear classifier and the loss, and load the checkpoint.

    :param opt: option namespace; ``opt.model``, ``opt.n_cls`` and ``opt.ckpt`` are read
    :return: a tuple ``(model, classifier, criterion)``, all moved to the GPU
    :rtype: tuple
    :raises NotImplementedError: if CUDA is not available
    """
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = ckpt['model']

    if not torch.cuda.is_available():
        raise NotImplementedError('This code requires GPU')

    if torch.cuda.device_count() > 1:
        model.encoder = torch.nn.DataParallel(model.encoder)
    else:
        # Without the DataParallel wrapper the parameter names carry no
        # "module." prefix, so drop it from the checkpoint keys.
        # NOTE(review): presumably the checkpoint was saved from a
        # DataParallel-wrapped encoder - confirm against the training script.
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    model = model.cuda()
    classifier = classifier.cuda()
    criterion = criterion.cuda()
    cudnn.benchmark = True

    # The weights read from ``opt.ckpt`` must actually be loaded into the model.
    model.load_state_dict(state_dict)

    return model, classifier, criterion

# def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
#     '''
#     :param train_ratio: split the ratio of the origin dataset as the train set
#     :type train_ratio: float
#     :param origin_dataset: the origin dataset
#     :type origin_dataset: torch.utils.data.Dataset
#     :param num_classes: total classes number, e.g., ``10`` for the MNIST dataset
#     :type num_classes: int
#     :param random_split: If ``False``, the front ratio of samples in each class will
#             be included in train set, while the rest will be included in test set.
#             If ``True``, this function will split samples in each class randomly. The randomness is controlled by
#             ``numpy.random.seed``
#     :type random_split: bool
#     :return: a tuple ``(train_set, test_set)``
#     :rtype: tuple
#     '''
#     label_idx = []
#     for i in range(num_classes):
#         label_idx.append([])
#     for i, item in enumerate(origin_dataset):
#         y = item[1]
#         if isinstance(y, np.ndarray) or isinstance(y, torch.Tensor):
#             y = y.item()
#         label_idx[y].append(i)
#     train_idx = []
#     test_idx = []
#     if random_split:
#         for i in range(num_classes):
#             np.random.shuffle(label_idx[i])
#     for i in range(num_classes):
#         pos = math.ceil(label_idx[i].__len__() * train_ratio)
#         train_idx.extend(label_idx[i][0: pos])
#         test_idx.extend(label_idx[i][pos: label_idx[i].__len__()])
#     return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)

def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """Train the linear classifier for one epoch on averaged encoder outputs.

    For every batch the encoder ``model`` is run ``opt.T`` times under
    ``torch.no_grad()`` and its outputs are averaged; only ``classifier``
    receives gradients, because the averaged output is detached.

    :param train_loader: iterable yielding ``(images, labels)`` batches
    :param model: the (converted) encoder, called once per time-step
    :param classifier: the classifier fed with the averaged encoder output
    :param criterion: loss applied to ``(output, labels)``
    :param optimizer: optimizer stepped once per batch
    :param epoch: epoch index, only used in the progress message
    :param opt: option namespace; ``opt.T`` and ``opt.print_freq`` are read
    :return: a tuple ``(average loss, average top-1 accuracy)``
    :rtype: tuple
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # compute loss
        with torch.no_grad():
            # Clear state left over from the previous batch before simulating
            # a new one.
            functional.reset_net(model)
            # The accumulator is per batch: carrying it over would mix the
            # previous batch's average into this one.
            out_fr = 0.
            for _ in range(opt.T):
                out_fr += model(images)
            out_fr = out_fr / opt.T
        output = classifier(out_fr.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

    return losses.avg, top1.avg

def validate(val_loader, model, classifier, criterion, opt):
    """Evaluate the classifier on averaged encoder outputs.

    For every batch the encoder ``model`` is run ``opt.T`` times and its
    outputs are averaged before being passed to ``classifier``; the whole
    pass runs under ``torch.no_grad()``.

    :param val_loader: iterable yielding ``(images, labels)`` batches
    :param model: the (converted) encoder, called once per time-step
    :param classifier: the classifier fed with the averaged encoder output
    :param criterion: loss applied to ``(output, labels)``
    :param opt: option namespace; ``opt.T`` and ``opt.print_freq`` are read
    :return: a tuple ``(average loss, average top-1 accuracy)``
    :rtype: tuple
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            # Clear state left over from the previous batch before simulating
            # a new one.
            functional.reset_net(model)
            # Start from a fresh accumulator for each batch; reusing the
            # previous average would contaminate the result and fail on a
            # final batch of a different size.
            out_fr = 0.
            for _ in range(opt.T):
                out_fr += model(images)
            out_fr = out_fr / opt.T
            output = classifier(out_fr.detach())
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    idx, len(val_loader), batch_time=batch_time,
                    loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg

def main():
    """Convert the encoder with ann2snn, then train and evaluate the classifier on it.

    Prints the per-epoch training summary and, at the end, the best
    validation accuracy seen over all epochs.
    """
    opt = parse_option()

    # data
    train_loader, val_loader = set_loader(opt)

    # encoder, classifier and loss; the optimizer is built for the classifier
    model, classifier, criterion = set_model(opt)
    optimizer = set_optimizer(opt, classifier)

    # ANN -> SNN conversion of the encoder; the training loader is handed to
    # the converter.
    # NOTE(review): presumably the converter only replaces nn.ReLU modules, so
    # an encoder that calls torch.nn.functional.relu would come back without
    # spiking neurons - confirm against the ann2snn documentation.
    converter = ann2snn.Converter(mode='99.9%', dataloader=train_loader)
    snn_model = converter(model.encoder)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        # one training epoch, timed
        started = time.time()
        _, acc = train(train_loader, snn_model, classifier, criterion,
                       optimizer, epoch, opt)
        finished = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, finished - started, acc))

        # one evaluation pass; remember the best accuracy so far
        _, val_acc = validate(val_loader, snn_model, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

    print('best accuracy: {:.2f}'.format(best_acc))

# Run the script only when executed directly, not when imported.
if __name__ == '__main__':
    main()
Met4physics commented 3 months ago


83517769 commented 3 months ago


谢谢 我大概理解了 非常感谢