Closed xianzhixiong closed 1 month ago
@xianzhixiong 你的pth我们没法加载,没法调试啊。你有原始模型的代码吗?另外,tinynn里面没有issubclass的调用,所以这应该不是tinynn的问题
import torch
import torch.nn as nn
from tinynn.graph.tracer import model_tracer
from tinynn.converter import TFLiteConverter
class SimpleNN(nn.Module):
def __init__(self, input_size, num_classes):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(input_size, 50)
self.bn1 = nn.LayerNorm(50) # 添加批量归一化层
self.fc2 = nn.Linear(50, 30)
self.bn2 = nn.LayerNorm(30) # 添加批量归一化层
self.fc3 = nn.Linear(30, 20)
self.bn3 = nn.LayerNorm(20) # 添加批量归一化层
self.fc4 = nn.Linear(20, num_classes)
def forward(self, x):
x = torch.relu(self.bn1(self.fc1(x))) # 在激活函数之前进行批量归一化
x = torch.relu(self.bn2(self.fc2(x))) # 在激活函数之前进行批量归一化
x = torch.relu(self.bn3(self.fc3(x))) # 在激活函数之前进行批量归一化
x = self.fc4(x)
return x
with model_tracer():
dummy_input = torch.tensor([[0, -48, -92, 60587, 2319]])
dummy_input = dummy_input.float()
print(dummy_input)
model = torch.load('mode2.pth', map_location=torch.device('cpu'))
model.cpu()
model.eval()
output_path = "out/temp.tflite"
model(dummy_input)
print("OK")
converter = TFLiteConverter(model, dummy_input, output_path)
converter.convert()
The above code works, It is recommended to update to the latest TinyNN and try it out.
在运行model时,正常,转换tflite时不正常,出现TypeError: issubclass() arg 2 must be a class, a tuple of classes, or a union mode2.zip