ThanatosShinji / onnx-tool

A parser, editor and profiler tool for ONNX models.
https://pypi.org/project/onnx-tool/
MIT License
399 stars, 52 forks

About GatherElements operator #59

Closed. Like2021 closed this issue 11 months ago.

Like2021 commented 11 months ago

Thanks for your excellent open-source work. I think I've run into a new problem.

My model contains a GatherElements node like this:

[Image: the GatherElements node and its input tensor shapes]

This node has two inputs, with shapes 1×195 and 1×196. When I run onnx_tool's shape_infer(), I found that the GatherElements node calls its own shape_infer directly:

class GatherElementsNode(Node):
    def shape_infer(self, intensors: List[Tensor], outtensors: List[Tensor]):
        outtensors[0].update_shape(intensors[0].get_shape())
        outtensors[0].update_dtype(intensors[0].dtype)

According to this code, the output shape is simply copied from the first input tensor, intensors[0]. For GatherElements, however, the output must have the same shape as the indices input. Here the indices input is 1×195, so the output should be 1×195, not 1×196.
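For reference, the ONNX specification defines GatherElements so that the output always has the same shape as indices. NumPy's np.take_along_axis follows the same semantics (for non-negative indices), so it can serve as a quick sanity check. The shapes below come from this issue; the data values and the choice of axis=1 are made up for illustration.

```python
import numpy as np

# Shapes taken from the issue: data is 1x196, indices is 1x195.
# The values themselves are arbitrary.
data = np.arange(196, dtype=np.float32).reshape(1, 196)
indices = np.arange(195, dtype=np.int64).reshape(1, 195)

# For axis=1 this computes out[i][j] = data[i][indices[i][j]],
# which matches ONNX GatherElements for non-negative indices.
out = np.take_along_axis(data, indices, axis=1)

print(out.shape)  # (1, 195): the shape of indices, not of data
```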

The graph-level shape_infer checks each node's shape_calc attribute and calls the node's own shape_infer when it is False:

def shape_infer(self, inputs: {} = None):
    ...
    if node.shape_calc:
        node.value_infer(itensors, otensors)
    else:
        node.shape_infer(itensors, otensors)

So I guess shape_calc should be True for GatherElements?

Could you give me some tips on how to handle this?
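To make the dispatch quoted above concrete, here is a minimal sketch using hypothetical stand-in classes (Tensor, BuggyGatherElementsNode and infer_node are illustrative names, not onnx_tool's real API). It shows that when shape_calc is False, the output shape comes entirely from the node's shape_infer, so the pre-fix GatherElements logic yields 1×196 for the shapes in this issue.

```python
from typing import List, Tuple


# Hypothetical, minimal stand-ins; these are NOT onnx_tool's real classes.
class Tensor:
    def __init__(self, shape: Tuple[int, ...] = ()):
        self.shape = tuple(shape)

    def get_shape(self) -> Tuple[int, ...]:
        return self.shape

    def update_shape(self, shape) -> None:
        self.shape = tuple(shape)


class BuggyGatherElementsNode:
    def __init__(self, shape_calc: bool = False):
        self.shape_calc = shape_calc
        self.called = None  # records which inference method ran

    def value_infer(self, intensors: List[Tensor], outtensors: List[Tensor]) -> None:
        # In onnx_tool this path computes real tensor values;
        # here it only records that it was chosen.
        self.called = "value_infer"

    def shape_infer(self, intensors: List[Tensor], outtensors: List[Tensor]) -> None:
        self.called = "shape_infer"
        # Behaviour quoted in the issue: copy the shape of input 0 (data).
        outtensors[0].update_shape(intensors[0].get_shape())


def infer_node(node, itensors: List[Tensor], otensors: List[Tensor]) -> None:
    # Same branch as the graph-level shape_infer excerpt quoted above.
    if node.shape_calc:
        node.value_infer(itensors, otensors)
    else:
        node.shape_infer(itensors, otensors)


node = BuggyGatherElementsNode(shape_calc=False)
data, indices, out = Tensor((1, 196)), Tensor((1, 195)), Tensor()
infer_node(node, [data, indices], [out])
print(node.called, out.get_shape())  # shape_infer (1, 196)
```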

ThanatosShinji commented 11 months ago

This looks like a bug in GatherElements. The output shape should be set with outtensors[0].update_shape(intensors[1].get_shape()). Was the graph with tensor shapes in your image generated by onnx_tool?
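A minimal sketch of the corrected rule, written as a free function over a hypothetical Tensor stand-in rather than onnx_tool's actual GatherElementsNode: the output takes its shape from the indices input (intensors[1]) and keeps its dtype from the data input (intensors[0]), as the ONNX spec requires.

```python
from typing import List, Tuple


# Hypothetical stand-in for onnx_tool's Tensor, reduced to shape and dtype.
class Tensor:
    def __init__(self, shape: Tuple[int, ...] = (), dtype: str = "float32"):
        self.shape = tuple(shape)
        self.dtype = dtype

    def get_shape(self) -> Tuple[int, ...]:
        return self.shape

    def update_shape(self, shape) -> None:
        self.shape = tuple(shape)

    def update_dtype(self, dtype) -> None:
        self.dtype = dtype


def gather_elements_shape_infer(intensors: List[Tensor], outtensors: List[Tensor]) -> None:
    data, indices = intensors[0], intensors[1]
    # GatherElements: the output shape follows indices ...
    outtensors[0].update_shape(indices.get_shape())
    # ... while the element type follows data.
    outtensors[0].update_dtype(data.dtype)


out = Tensor()
gather_elements_shape_infer([Tensor((1, 196), "float32"), Tensor((1, 195), "int64")], [out])
print(out.get_shape(), out.dtype)  # (1, 195) float32
```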

ThanatosShinji commented 11 months ago

I've fixed the GatherElements bug in commit 1d170c782aeaa59771bd3d0c4f4e556769d38a54. Thanks for your feedback!

Like2021 commented 11 months ago

No, it was not generated by onnx_tool.