Open trefoil0219 opened 8 months ago
pytorch版本:2.0.0+cpu
# Custom activation operators.
import torch
import torch.nn as nn


class CRELU(nn.Module):
    """Clipped ReLU: clamp every element of the input into [0, 1].

    Autograd differentiates ``torch.clamp`` directly and yields a gradient of
    1 where ``0 <= x <= 1`` and 0 elsewhere. That is the masking the earlier
    hand-written ``backward`` tried to express. The method was removed
    because autograd never calls ``nn.Module.backward``, and that version
    read a non-existent ``self.saved_tensors`` and applied ``or`` to tensors.
    """

    def forward(self, x):
        # The original signature was ``forward(x)`` (left over from a
        # commented-out @staticmethod), so ``CRELU()(t)`` raised TypeError.
        return torch.clamp(x, min=0, max=1)


class Qt(nn.Module):
    """Clamp every element of the input into [0, 1].

    Produces exactly what the previous nested ``torch.where`` form produced
    (values below 0 -> 0, above 1 -> 1, otherwise unchanged). Autograd gives
    the same gradient too: 1 inside [0, 1], 0 outside.
    """

    def forward(self, x):
        return torch.clamp(x, min=0, max=1)

    def backward(self, grad_output):
        """Return ``grad_output`` unchanged.

        NOTE(review): autograd never calls ``nn.Module.backward``. The
        gradient actually used in training is the masked gradient of
        ``forward``. If a straight-through (identity) gradient was intended,
        it has to be implemented with ``torch.autograd.Function``.
        TODO: confirm the intent.
        """
        return grad_output
# Model architecture.
import torch
import torch.nn as nn


class CNNforArousal(nn.Module):
    """1-D CNN producing 2 output values per sample, built from clipped activations.

    Three stages of Conv1d -> BatchNorm1d -> CRELU -> Qt, with max pooling
    after the 2nd and 3rd stages, followed by fc1 -> Qt -> fc2.

    Input is presumably ``(batch, 32, length)``. ``fc1`` expects
    ``32 * 76 = 2432`` flattened features per sample, which holds for
    ``length`` in 320..323. TODO: confirm the intended input length.

    NOTE(review): SpikingJelly's ann2snn converter only replaces ``nn.ReLU``
    modules, so the CRELU / Qt activations here are not turned into spiking
    neurons by it.
    """

    def __init__(self):
        super().__init__()
        # Stage 1
        self.conv1 = nn.Conv1d(32, 16, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.crelu1 = CRELU()
        self.qt1 = Qt()
        # Stage 2
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.crelu2 = CRELU()
        self.qt2 = Qt()
        self.maxpool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Stage 3
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1)
        self.bn3 = nn.BatchNorm1d(32)
        self.crelu3 = CRELU()
        self.qt3 = Qt()
        self.maxpool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Classifier head
        self.fc1 = nn.Linear(2432, 128)
        # Distinct instance for the post-fc1 clamp. The original re-assigned
        # ``self.qt3`` here, so one module was silently shared by two call sites.
        self.qt4 = Qt()
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.crelu1(x)
        x = self.qt1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.crelu2(x)
        x = self.qt2(x)
        x = self.maxpool1(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.crelu3(x)
        x = self.qt3(x)
        x = self.maxpool2(x)

        # Flatten per sample and keep the batch dimension. The original called
        # ``x.flatten()`` (which also merges the batch dimension) and then
        # ``unsqueeze(0)``. That only worked for batch size 1, where this gives
        # the identical (1, 2) result.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.qt4(x)
        x = self.fc2(x)
        return x
# Conversion code: turn the trained ANN into an SNN with SpikingJelly.
from spikingjelly.activation_based import ann2snn

# NOTE(review): the converter only rewrites nn.ReLU modules, so custom
# activations in ``model`` are not replaced by spiking neurons.
# ``train_dataloader`` and ``model`` are expected to be defined earlier.
model_converter = ann2snn.Converter(
    dataloader=train_dataloader,
    mode='max',
)
snn_model = model_converter(model)
#转换代码的输出结果
snn_model  #输出以下结果
CNNforArousal(
  (conv1): Conv1d(32, 16, kernel_size=(5,), stride=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
  (maxpool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(5,), stride=(1,))
  (maxpool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2432, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)
你好，目前 ANN2SNN 转换的代码逻辑是把 nn.ReLU 替换为 SNN tailor。所以如果您使用的是自定义的激活函数，就需要手动修改代码：把 spikingjelly/activation_based/ann2snn/converter.py（原路径：/remote-home/lvliuzh/temp/spikingjelly/spikingjelly/activation_based/ann2snn/converter.py）中所有涉及 nn.ReLU 的类型判断，都改为您自定义的激活函数的类型。@populustremble
pytorch版本:2.0.0+cpu