Closed: SuyueLiu closed this issue 1 year ago.
Hello, I tried your code and the conversion succeeded; please double-check your own code.
Sample code:
# generate onnx model.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class SEBlock(nn.Module):  # Squeeze-and-Excitation block
    """Squeeze-and-Excitation (SE) channel-attention block.

    Global-average-pools the input and passes the result through a
    bottleneck of two 1x1 convolutions (ReLU in between, Sigmoid at the
    end). This yields one weight in (0, 1) per channel, which is then
    multiplied back onto the input feature map.

    Args:
        channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the bottleneck (the SE paper uses 16).
    """

    def __init__(self, channels, ratio=16):
        super(SEBlock, self).__init__()
        # Bottleneck width. Clamp to at least 1 so that channels < ratio
        # does not yield a zero-width (degenerate) convolution.
        hidden_channels = max(1, channels // ratio)
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # squeeze: global average pool to 1x1
            nn.Conv2d(channels, hidden_channels, 1, 1, 0),  # 1x1 conv in place of a linear layer
            nn.ReLU(),
            nn.Conv2d(hidden_channels, channels, 1, 1, 0),  # 1x1 conv in place of a linear layer
            nn.Sigmoid(),  # squash the weights into (0, 1)
        )

    def forward(self, x):
        """Return ``x`` with each channel scaled by its SE attention weight."""
        # For a 4-D (N, C, H, W) input, weights is (N, C, 1, 1) and
        # broadcasts over the spatial dimensions.
        weights = self.attn(x)
        return weights * x
# Build the SE model and export it to ONNX.
model = SEBlock(16)
# Make inference mode explicit for export. SEBlock has no train/eval-dependent
# layers, but this keeps the export correct if such layers are added.
model.eval()
X = torch.randn((1, 16, 50, 50))
output_dir = "./models/"
onnx_path = os.path.join(output_dir, "SEModel.onnx")  # "./models/SEModel.onnx"
# Ensure the output directory exists before writing the ONNX file into it.
os.makedirs(output_dir, exist_ok=True)
torch.onnx.export(model, X, onnx_path, opset_version=11)
# Convert to tflite
from converter import onnx_converter
res = onnx_converter(
    onnx_model_path=onnx_path,
    need_simplify=True,
    output_path=output_dir,
    target_formats=['tflite'],
    native_groupconv=True,
)
print(res)
# output: {'keras': None, 'tflite': './models/SEModel.tflite', 'keras_error': None, 'tflite_error': 2.9802322e-08}
Hi, thanks for your work. I ran into a problem when converting an ONNX (PyTorch) model to a TFLite model. Before I added an SE module, the difference between the ONNX and TFLite model outputs was reasonable. After adding an SE module to the original network, the difference became much larger; for example, the mean difference was 3 or even more.
I use the simple SE module below.