ThanatosShinji / onnx-tool

A parser, editor and profiler tool for ONNX models.
https://pypi.org/project/onnx-tool/
MIT License

Questions about graph.shape_infer() #62

Closed: hollyaoyaozi closed this issue 10 months ago

hollyaoyaozi commented 10 months ago

Hi dear author,

In graph.shape_infer()'s implementation, there is an if-else branch that checks whether the current node's shape_calc attribute is True (node.value_infer is called) or False (node.shape_infer is called).

My questions:
1) Why is only one of shape_infer or value_infer implemented for some kinds of nodes? For example, PadNode implements shape_infer but not value_infer. Does this mean that implementing only one of these two functions is enough to run graph.shape_infer() and graph.profile()?
2) I found that PadNode's shape_calc value differs between models: in some ONNX models it is False, in others it is True, and the True case can lead to a NotImplementedError from value_infer. So what determines a node's shape_calc value?
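The if-else branch and the resulting NotImplementedError can be illustrated with a minimal sketch. The classes and the infer_node function below are hypothetical: only the names shape_calc, shape_infer, value_infer and PadNode come from this thread, and this is not onnx-tool's actual code. It shows how a node that implements only shape_infer works while shape_calc is False and fails once shape_calc is True.

```python
class Node:
    """Hypothetical minimal node; attribute and method names follow the issue
    (shape_calc, shape_infer, value_infer), not onnx-tool's real classes."""

    def __init__(self, name, inputs, outputs):
        self.name = name
        self.inputs = inputs      # input tensor names
        self.outputs = outputs    # output tensor names
        self.shape_calc = False   # True: output *value* is needed for shape inference

    def shape_infer(self, shapes, values):
        raise NotImplementedError(f"{type(self).__name__}.shape_infer")

    def value_infer(self, shapes, values):
        raise NotImplementedError(f"{type(self).__name__}.value_infer")


class PadNode(Node):
    """Implements only shape_infer, like the PadNode described in the issue."""

    def __init__(self, name, inputs, outputs, pads):
        super().__init__(name, inputs, outputs)
        self.pads = pads  # ONNX layout: [x1_begin, ..., xn_begin, x1_end, ..., xn_end]

    def shape_infer(self, shapes, values):
        in_shape = shapes[self.inputs[0]]
        rank = len(in_shape)
        return [[d + self.pads[i] + self.pads[i + rank] for i, d in enumerate(in_shape)]]


class ShapeNode(Node):
    """Implements both; its output value is the input's shape."""

    def shape_infer(self, shapes, values):
        return [[len(shapes[self.inputs[0]])]]

    def value_infer(self, shapes, values):
        return [list(shapes[self.inputs[0]])]


def infer_node(node, shapes, values):
    """Hypothetical mirror of the branch: value_infer if shape_calc, else shape_infer."""
    if node.shape_calc:
        for name, val in zip(node.outputs, node.value_infer(shapes, values)):
            values[name] = val
            shapes[name] = [len(val)]  # this sketch only handles 1-D shape tensors
    else:
        for name, shp in zip(node.outputs, node.shape_infer(shapes, values)):
            shapes[name] = shp


# Usage
shapes, values = {"x": [1, 3, 4, 4]}, {}
pad = PadNode("pad", ["x"], ["y"], pads=[0, 0, 1, 1, 0, 0, 1, 1])
infer_node(pad, shapes, values)  # shape_calc is False, so shape_infer runs
print(shapes["y"])               # [1, 3, 6, 6]

pad.shape_calc = True            # same node type, but now its output feeds a shape
try:
    infer_node(pad, shapes, values)
except NotImplementedError as err:
    print("NotImplementedError:", err)  # NotImplementedError: PadNode.value_infer
```

In this sketch, implementing one function is enough only as long as the graph never routes the node into the other branch.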

Thanks


ThanatosShinji commented 10 months ago

Hi @hollyaoyaozi, here is the definition: shape_calc=True means that this node's output tensor value will be used for shape inference.

In most cases, the output tensor's shape can be inferred from the input tensors' shapes alone, as with Gemm or Mul nodes. But for nodes like Reshape, the output shape can only be inferred from the value of an input tensor (the shape tensor). So the nodes that produce this shape tensor should be marked shape_calc=True.
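One way to derive shape_calc from the graph structure is to walk backward from every shape-determining input and mark each producer on the way. The sketch below is a hypothetical illustration, not onnx-tool's algorithm: the Node dataclass, the SHAPE_VALUE_INPUTS table (a partial list taken from the ONNX operator specs) and mark_shape_calc are assumptions. It stops at Shape nodes because a Shape node needs only its input's shape, not its value. The usage part also shows why the same op type (Pad) can be False in one model and True in another.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    op_type: str
    inputs: list
    outputs: list
    shape_calc: bool = False


# Partial, hypothetical table (not onnx-tool's): input indices whose *values*
# determine an op's output shape, per the ONNX operator specs.
SHAPE_VALUE_INPUTS = {
    "Reshape": [1],          # shape
    "Expand": [1],           # shape
    "Tile": [1],             # repeats
    "Pad": [1],              # pads (opset >= 11)
    "ConstantOfShape": [0],  # input (the shape)
}


def mark_shape_calc(nodes):
    """Mark every node whose output value flows into a shape-determining input."""
    producer = {out: n for n in nodes for out in n.outputs}
    stack = [n.inputs[i] for n in nodes
             for i in SHAPE_VALUE_INPUTS.get(n.op_type, []) if i < len(n.inputs)]
    seen = set()
    while stack:
        tensor = stack.pop()
        if tensor in seen:
            continue
        seen.add(tensor)
        node = producer.get(tensor)
        if node is None:   # graph input or initializer: nothing to mark
            continue
        node.shape_calc = True
        if node.op_type == "Shape":
            continue       # Shape needs only its input's shape, not its value
        stack.extend(node.inputs)


# Graph A: Pad acts on ordinary data, so it stays False.
a = [Node("pad", "Pad", ["x", "pads"], ["y"]),
     Node("reshape", "Reshape", ["y", "target"], ["out"])]
mark_shape_calc(a)
print([(n.name, n.shape_calc) for n in a])  # [('pad', False), ('reshape', False)]

# Graph B: Pad edits a shape vector that Reshape consumes, so it becomes True.
b = [Node("shape", "Shape", ["data"], ["s0"]),
     Node("pad", "Pad", ["s0", "pads"], ["s1"]),
     Node("reshape", "Reshape", ["data", "s1"], ["out"])]
mark_shape_calc(b)
print([(n.name, n.shape_calc) for n in b])
# [('shape', True), ('pad', True), ('reshape', False)]
```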

The same node type can indeed be marked True in one model and False in another. It depends on the graph structure, i.e., on whether that node's output eventually feeds a shape tensor.

ThanatosShinji commented 10 months ago

There is another situation that may require nodes' value_infer regardless of shape_calc: constant folding. If a node is in a constant subgraph, onnx_tool will try to pre-compute its result and save it as a constant tensor. In this case, every node in the constant subgraph needs to call its value_infer function.
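A minimal sketch of the idea, assuming a topologically sorted node list and a per-node value_infer callable: any node whose inputs are all constants is pre-computed, and its outputs become new constants. Node, fold_constants and the lambda-based value_infer are hypothetical names for illustration, not onnx-tool's API. It shows why constant folding hits value_infer on every node in the constant subgraph, so a node without value_infer raises NotImplementedError.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Node:
    name: str
    inputs: List[str]
    outputs: List[str]
    # Hypothetical: computes output values from input values.
    value_infer: Callable[[List[Any]], List[Any]]


def fold_constants(nodes: List[Node],
                   initializers: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Node]]:
    """Pre-compute every node whose inputs are all constant.

    Assumes `nodes` is topologically sorted. Nodes with no inputs (e.g. Constant)
    count as constant here, a simplification that ignores ops like RandomNormal.
    """
    consts = dict(initializers)
    remaining = []
    for node in nodes:
        if all(i in consts for i in node.inputs):
            # Every node in the constant subgraph must support value_infer;
            # a NotImplementedError here propagates to the caller.
            outs = node.value_infer([consts[i] for i in node.inputs])
            consts.update(zip(node.outputs, outs))
        else:
            remaining.append(node)
    return consts, remaining


# Usage: "add" only reads initializers, so it is folded; "mul" reads graph input x.
nodes = [
    Node("add", ["a", "b"], ["c"], lambda v: [v[0] + v[1]]),
    Node("mul", ["c", "x"], ["y"], lambda v: [v[0] * v[1]]),
]
consts, remaining = fold_constants(nodes, {"a": 2, "b": 3})
print(consts["c"], [n.name for n in remaining])  # 5 ['mul']
```

In this sketch, skipping fold_constants entirely is the analogue of disabling constant folding: value_infer is then only reached through the shape_calc branch.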

You can try disabling the constant_folding parameter if you encounter a NotImplementedError.

hollyaoyaozi commented 10 months ago

Got it. Thank you very much!