
NNgen: A Fully-Customizable Hardware Synthesis Compiler for Deep Neural Network
Apache License 2.0

Sample code for ONNX doesn't work for mobilenet_v2 #65

Open MaKaRoIIIKa opened 1 year ago

MaKaRoIIIKa commented 1 year ago

Hello! I'm trying to use your tool by simply replacing the resnet18 model in the ONNX sample with mobilenet_v2. When converting the model from ONNX I get the error `KeyError: 'Clip'`. I tried adding the entry `'Clip': act_func.Relu` to the `func_map` dictionary in `__init__.py`, but then I get the error `ValueError: input and filter must have a same input channel length as shape[3]: '32' != '1'`. Please help resolve these errors. Code:

import nngen as ng
# --------------------
# (1) Represent a DNN model as a dataflow by NNgen operators
# --------------------

# data types
act_dtype = ng.int8
weight_dtype = ng.int8
bias_dtype = ng.int32
scale_dtype = ng.int8
#batchsize = 1

import torch
import torchvision

# model = torchvision.models.resnet18(pretrained=True)
model = torchvision.models.mobilenet_v2(pretrained=True)

# Pytorch to ONNX
onnx_filename = 'mobilenet_v2.onnx'
dummy_input = torch.randn(1, 3, 128, 128)
input_names = ['act']
output_names = ['out']
model.eval()
torch.onnx.export(model, dummy_input, onnx_filename, input_names=input_names, output_names=output_names)

# ONNX to NNgen
dtypes = {}
(outputs, placeholders, variables, constants, operators) = ng.from_onnx(onnx_filename,
                                                                        value_dtypes=dtypes,
                                                                        default_placeholder_dtype=act_dtype,
                                                                        default_variable_dtype=weight_dtype,
                                                                        default_constant_dtype=weight_dtype,
                                                                        default_operator_dtype=act_dtype,
                                                                        default_scale_dtype=scale_dtype,
                                                                        default_bias_dtype=bias_dtype,
                                                                        disable_fusion=False)
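
MobileNetV2 uses ReLU6 activations, which torch.onnx.export emits as ONNX Clip nodes, so an operator map without a handler for Clip raises the KeyError above. Below is a minimal sketch of the workaround described in the report, assuming the func_map dictionary and act_func module are exposed from nngen.onnx as the report suggests; mapping Clip to plain Relu drops the upper clamp at 6, so it only approximates ReLU6 and does not address the later channel-shape error.

import nngen.onnx as ng_onnx

# Assumed patch point from the report: register a handler for the ONNX 'Clip'
# operator in NNgen's ONNX operator map before calling ng.from_onnx.
# Plain Relu ignores Clip's max attribute (6 for ReLU6), so this is only an
# approximation of MobileNetV2's activations.
ng_onnx.func_map['Clip'] = ng_onnx.act_func.Relu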