fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs), based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.24k stars 235 forks source link

使用自定义激活函数后，把 ANN 转换成 SNN 失败 (ANN-to-SNN conversion fails after using a custom activation function) #461

Open trefoil0219 opened 8 months ago

trefoil0219 commented 8 months ago

PyTorch 版本 (version)：2.0.0+cpu

# Custom operators (自定义算子)
class CRELU(nn.Module):
    """Clipped ReLU: clamp every element of the input to ``[0, 1]``.

    Autograd differentiates ``torch.clamp`` directly, so no hand-written
    backward is needed for training. ``backward`` is kept only as an explicit
    reference implementation of the same gradient; ``nn.Module`` never
    invokes a method named ``backward`` (custom gradients require
    ``torch.autograd.Function``).

    NOTE: per the maintainer reply on this issue, spikingjelly's
    ``ann2snn.Converter`` only recognises ``nn.ReLU``. This module is
    therefore not replaced by a spiking neuron during conversion.
    """

    @staticmethod
    def forward(x):
        """Return ``x`` clamped element-wise to ``[0, 1]``."""
        return torch.clamp(x, min=0, max=1)

    def backward(self, grad_output, x=None):
        """Gradient of ``clamp(x, 0, 1)`` with respect to ``x``.

        Args:
            grad_output: upstream gradient, broadcastable against ``x``.
            x: the input that was passed to ``forward``. It is optional only
                for backward compatibility. When omitted, ``self.saved_tensors``
                is used if something has set it.

        Returns:
            ``grad_output`` with entries zeroed where ``x < 0`` or ``x > 1``.
            Entries at the boundaries 0 and 1 pass through.

        Raises:
            RuntimeError: if ``x`` is not given and no saved input exists.
        """
        if x is None:
            # The original code read the input from ``self.saved_tensors``.
            # That is an ``autograd.Function`` context attribute which
            # ``nn.Module`` does not define, so it always failed.
            saved = getattr(self, "saved_tensors", None)
            if saved is None:
                raise RuntimeError(
                    "CRELU.backward needs the forward input `x`; "
                    "autograd already differentiates torch.clamp"
                )
            (x,) = saved
        # Element-wise mask instead of ``if (x < 0) or (x > 1)``, which raises
        # "Boolean value of Tensor ... is ambiguous" for multi-element tensors.
        inside = (x >= 0) & (x <= 1)
        return torch.where(inside, grad_output, torch.zeros_like(grad_output))
class Qt(nn.Module):
    """Clip activations to the closed interval ``[0, 1]`` via ``torch.where``.

    NOTE(review): ``backward`` below is an identity (straight-through)
    gradient, but ``nn.Module`` never calls a method named ``backward``.
    Autograd differentiates the ``torch.where`` graph instead, so clipped
    elements receive zero gradient. If a straight-through estimator is
    intended, it must be implemented with ``torch.autograd.Function``.
    """

    def forward(self, x):
        """Return ``x`` with values below 0 set to 0 and above 1 set to 1."""
        upper_fill = torch.ones_like(x)
        lower_fill = torch.zeros_like(x)
        capped = torch.where(x > 1, upper_fill, x)  # apply the upper bound first
        return torch.where(x < 0, lower_fill, capped)

    def backward(self, grad_output):
        """Return ``grad_output`` unchanged (identity gradient)."""
        return grad_output
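# Model architecture (模型结构)
# NOTE: the two imports below must appear before CRELU/Qt above (they use
# nn.Module at class-definition time) for this code to run as one file.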
import torch
import torch.nn as nn

class CNNforArousal(nn.Module):
    """1-D CNN with two outputs built from Conv1d / BatchNorm1d stages.

    Layout:
      * three ``Conv1d -> BatchNorm1d -> CRELU -> Qt`` stages, where the
        second and third stages are each followed by ``MaxPool1d(2, 2)``;
      * ``Linear(2432, 128) -> Qt -> Linear(128, 2)``.

    NOTE(review): ``fc1`` expects ``32 * 76 = 2432`` flattened features. With
    the unpadded kernel-5 convolutions and the two poolings, that corresponds
    to an input length of 320 (or 321). Confirm against the data pipeline.

    NOTE: per the maintainer reply on this issue, spikingjelly's
    ``ann2snn.Converter`` only replaces ``nn.ReLU`` modules with spiking
    neurons. ``CRELU``/``Qt`` are not recognised, which is why the printed
    conversion result contains no neuron layers.
    """

    def __init__(self):
        super(CNNforArousal, self).__init__()

        self.conv1 = nn.Conv1d(32, 16, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.crelu1 = CRELU()
        self.qt1 = Qt()

        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.crelu2 = CRELU()
        self.qt2 = Qt()

        self.maxpool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1)
        self.bn3 = nn.BatchNorm1d(32)
        self.crelu3 = CRELU()
        self.qt3 = Qt()

        self.maxpool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(2432, 128)
        # Distinct activation for the MLP. The original re-assigned ``qt3``
        # here and reused it after both conv3 and fc1. Qt is stateless, so
        # numerics are unchanged, but the module listing is now unambiguous.
        self.qt4 = Qt()
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Run the network.

        Args:
            x: batched input of shape ``(batch, 32, length)``.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.crelu1(x)
        x = self.qt1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.crelu2(x)
        x = self.qt2(x)

        x = self.maxpool1(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.crelu3(x)
        x = self.qt3(x)
        x = self.maxpool2(x)

        # Flatten everything except the batch dimension. The original
        # ``x.flatten()`` also folded the batch into the feature axis, so
        # ``fc1`` failed for any batch size > 1, and ``unsqueeze(0)`` was
        # re-adding the lost batch axis. For batch size 1 the result is the
        # same ``(1, 2)`` tensor as before.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.qt4(x)
        x = self.fc2(x)
        return x
# Conversion code (转换代码)
from spikingjelly.activation_based import ann2snn
# NOTE(review): `train_dataloader` and `model` (presumably a trained
# CNNforArousal instance) are defined outside this snippet — confirm.
# mode='max' presumably derives the activation scaling from the maximum
# values seen on `train_dataloader` — TODO confirm against the spikingjelly
# Converter documentation.
# Per the maintainer reply below, the converter only replaces nn.ReLU
# modules. CRELU/Qt are not converted, which is presumably why the printed
# result below contains no spiking-neuron layers.
model_converter = ann2snn.Converter(mode='max', dataloader=train_dataloader)
snn_model = model_converter(model)
# Output of the conversion code (转换代码的输出结果)
# Bare expression: displays the converted model in a notebook/REPL.
snn_model
# The following is printed (输出以下结果):
CNNforArousal(
  (conv1): Conv1d(32, 16, kernel_size=(5,), stride=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
  (maxpool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(5,), stride=(1,))
  (maxpool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2432, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)
Lyu6PosHao commented 6 months ago

你好，现在 ANN2SNN 转换的代码逻辑是把 ReLU 替换为 SNN tailor。所以如果您自定义了一个激活函数，就需要手动改写代码了：把 `spikingjelly/activation_based/ann2snn/converter.py` 里涉及 `nn.ReLU` 的类型判断，都改为您自定义的激活函数的类型。@populustremble