Open xiaotongnii opened 5 months ago
python api infer shape
// onnxsim pipeline: compose the optimize+shape-inference pass with constant
// folding and iterate the combined pass to a fixed point (bounded by
// `fixed_point_iters`; `converged` reports whether it actually stabilized).
// NOTE(review): fragment from onnxsim — the enclosing function is not shown here.
auto OptAndShapeAndFold =
FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant},
fixed_point_iters, &converged);
// Run the fixed-point pass over the input model to produce the simplified model.
auto sim_model = OptAndShapeAndFold(model);
C++ infer shape implementation (onnxsim.cpp)
// Runs ONNX shape inference on a copy of `model` and returns the annotated
// copy; the caller's model is left untouched.
//
// @param model  the model to infer shapes for (read-only)
// @return       a deep copy of `model` with inferred type/shape information
onnx::ModelProto _InferShapes(const onnx::ModelProto& model) {
  // Protobuf copy construction performs the same deep copy as CopyFrom().
  onnx::ModelProto inferred{model};
  onnx::shape_inference::InferShapes(inferred);
  return inferred;
}
onnx inferShape impl Lib\site-packages\onnx\shape_inference
// Entry point of ONNX shape inference (onnx/shape_inference): annotates the
// graph of `m` in place with inferred types and shapes.
//
// @param m                数据 model to annotate; mutated in place via mutable_graph().
// @param schema_registry  lookup source for per-op schemas.
// @param options          inference options (e.g. check_type, enable_data_propagation).
// @param generated_shape_data_by_name
//        optional output container for propagated shape data; may be nullptr
//        unless data propagation is enabled (see process() below).
void InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry,
const ShapeInferenceOptions& options,
std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
auto opset_imports = GetOpsetImportsFromProto(m);
SymbolTableImpl symbol_table;
// Index the model-local functions by (domain, name) so that nodes whose
// op_type resolves to a local function can be inferred through it.
ModelLocalFunctionsMap model_local_functions_by_id;
for (const auto& function_proto : m.functions()) {
model_local_functions_by_id.insert(
{GetModelLocalFunctionsMapIdentifier(function_proto.domain(), function_proto.name()), &function_proto});
}
// Delegate to the implementation; the empty map is the set of
// outer-scope value types (none for a top-level graph).
InferShapesImpl(
m.mutable_graph(),
std::unordered_map<std::string, TypeProto*>(0),
opset_imports,
options,
&symbol_table,
model_local_functions_by_id,
schema_registry,
generated_shape_data_by_name,
m.ir_version());
}
// Infers the output types/shapes of a single node `n` and records them.
//
// NOTE(review): this is a member of the shape-inference implementation class;
// `opset_imports`, `schema_registry`, `value_types_by_name`, `options`,
// `has_unsupported_op`, `inference_errors`, etc. are members of that enclosing
// scope, which is not shown in this excerpt.
//
// Flow: resolve the node's opset domain -> look up its schema -> run the
// schema's inference function (or infer through a function body) -> write the
// inferred types onto the node's outputs -> optionally propagate shape data.
void process(NodeProto& n) {
// Resolve domain for node
auto dit = opset_imports.find(n.domain());
if (dit == opset_imports.end()) {
// Both "" and "ai.onnx" refer to the default ONNX domain
if (n.domain() == "") {
dit = opset_imports.find("ai.onnx");
}
// Still unresolved: the model imports no opset for this node's domain.
if (dit == opset_imports.end()) {
fail_type_inference(
"Cannot infer type and shape for node name ",
n.name(),
". No opset import for domain",
n.domain(),
" optype ",
n.op_type());
}
}
auto domain_version = dit->second;
// May be null when the op is unknown to the registry for this version.
const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
InferenceContextImpl ctx(
n,
value_types_by_name,
input_data_by_name,
input_sparse_data_by_name,
generated_shape_data_by_name,
&graph_inference_context);
ONNX_TRY {
if (schema) {
if (schema->has_type_and_shape_inference_function()) {
// Primary path: the op's registered inference function computes the
// output types/shapes into `ctx`.
schema->GetTypeAndShapeInferenceFunction()(ctx);
} else if (schema->HasFunction()) {
// No inference function, but the op has a function-body definition:
// infer by expanding and inferring through that body.
InferShapeForFunctionNode(
*(schema->GetFunction()),
schema_registry,
ctx,
options,
model_local_functions_map,
symbol_table,
generated_shape_data_by_name);
} else {
// Continue with inference for remaining nodes
return;
}
} else if (model_local_functions_map.size() > 0) {
// Unknown to the registry: the op may be a model-local function.
auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
if (iter != model_local_functions_map.end()) {
InferShapeForFunctionNode(
*(iter->second),
schema_registry,
ctx,
options,
model_local_functions_map,
symbol_table,
generated_shape_data_by_name);
} else {
// Not a local function either: mark and skip rather than fail.
has_unsupported_op = true;
return;
}
} else {
has_unsupported_op = true;
return;
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op && !has_experimental_op) {
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
}
});
// Continue with inference for remaining nodes
return;
}
ONNX_TRY {
// check the type-equality for input and output
if (options.check_type && schema) {
schema->CheckInputOutputType(ctx);
}
// Record the inferred type for each named output so downstream nodes
// can see it via value_types_by_name.
for (int i = 0; i < n.output_size(); ++i) {
// skip type and shape propagation for missing optional outputs.
if (!n.output(i).empty())
updateType(n.output(i), ctx.getOutputType(i));
}
preprocess(n);
// If data propagation is enabled, propagate shape data if it exists.
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
if (generated_shape_data_by_name == nullptr) {
fail_shape_inference(
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
}
DataPropagationContextImpl data_propagation_ctx(
n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
schema->GetDataPropagationFunction()(data_propagation_ctx);
}
}
ONNX_CATCH(const std::runtime_error& err) {
ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
}
}
schema 是一个描述 ONNX 模型中操作(op)的规范。每个 op 都有一个对应的 schema,它定义了该 op 的输入和输出参数类型、名称、形状等信息。 ONNX 中的每个 op 都有一个唯一的名称,并且每个 op 的 schema 都可以通过 ONNX 官方文档或 ONNX 运行时 API 来获取。 在 ONNXShapeInference 类中的 process 函数中,schema 是指当前正在处理的节点(NodeProto)对应的 schema。通过调用 schema_registry->GetSchema(n.op_type(), domain_version, n.domain()) 方法,可以获取到该节点的 schema
name: "Conv" since_version: "1" description: """ Performs a convolution operation on the input tensor. The output is a tensor with the same rank as the input. The convolution operation can be performed on any number of dimensions, but it is most commonly used for 2D images. """
input [ { name: "X" description: "The input tensor." type: T shape: [D1, ..., Dn] }, { name: "W" description: "The weights tensor." type: T shape: [M1, ..., Mm, K1, ..., Kn] } ]
output [ { name: "Y" description: "The output tensor." type: T shape: [D1, ..., Dn] } ]
where: T = {tensor(float), tensor(double)} n >= 2 m >= 2 D1, ..., Dn are the dimensions of the input tensor M1, ..., Mm are the dimensions of the weights tensor K1, ..., Kn are the kernel sizes
attribute { name: "strides" type: INTS description: "The strides of the convolution operation." default: [1, ..., 1] }
attribute { name: "pads" type: INTS description: "The paddings of the convolution operation." default: [0, ..., 0] }
attribute { name: "dilations" type: INTS description: "The dilations of the convolution operation." default: [1, ..., 1] }
attribute { name: "group" type: INT description: "The number of groups to split the input and output channels into." default: 1 }
网络shape信息,input x shape 为[1,8,80],由Unsqueeze op 后,扩展维度为[1,1,8,80],但网络 Unsqueeze shape为[0,1,0,80],可见 Unsqueeze infer shape 时发生了错误,导致Conv output shape 发生错误。
Node Name: /encoder_embed/Unsqueeze Node OpType: Unsqueeze Input Shapes: [] Output Shapes: [(0, 1, 0, 80)]
Node Name: /encoder_embed/conv/0/Conv Node OpType: Conv Input Shapes: [(0, 1, 0, 80)] Output Shapes: [(0, 8, 0, 80)] Weights Shape: (8, 1, 3, 3)
Node Name: /encoder_embed/conv/3/Sub Node OpType: Sub Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Node Name: /encoder_embed/conv/3/Max Node OpType: Max Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Node Name: /encoder_embed/conv/3/Sub_1 Node OpType: Sub Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Node Name: /encoder_embed/conv/3/Abs Node OpType: Abs Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Node Name: /encoder_embed/conv/3/Neg Node OpType: Neg Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Node Name: /encoder_embed/conv/3/Exp Node OpType: Exp Input Shapes: [(0, 8, 0, 80)] Output Shapes: [(0, 8, 0, 80)]
Conv op infer shape failed: x[1,8,80] -> Unsqueeze(axis 1) -> Conv(kernel 8,1,3,3) -> expected [1,8,x,80]
x输入形状:[1, 1, 8, 80] 卷积核形状:[8, 1, 3, 3] 填充(pads):[0, 1, 0, 1] 步幅(strides):[1, 1] 膨胀(dilations):[1, 1]