PaddlePaddle / Paddle2ONNX

ONNX Model Exporter for PaddlePaddle
Apache License 2.0

Test tool bug: onnxbase.py does not support "tensorlist" as input when testing operators #1272

Open · jiuyuedeyu156 opened this issue 3 weeks ago

jiuyuedeyu156 commented 3 weeks ago

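Description: When testing operators, the test tool does not support a tensorlist as input.

Cause of the error: I ran an operator unit test whose network input is in tensorlist form, i.e. shape = [paddle.to_tensor(3), paddle.to_tensor(4)]: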

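import paddle
# APIOnnx is the test helper defined in tests/onnxbase.py
from onnxbase import APIOnnx


class Net3(paddle.nn.Layer):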
    """
    simple Net
    """

    def __init__(self):
        super(Net3, self).__init__()

    def forward(self, shape):
        """
        forward
        """
        print(shape)
        x = paddle.empty(shape, dtype=paddle.float32)
        print(x)
        return x

def test_empty_11_3():
    """
    api: paddle.empty
    op version: 11
    """
    op = Net3()
    op.eval()
    obj = APIOnnx(op, 'empty', [11])
    # "tensorlist" input: the shape is a Python list of 0-D tensors
    shape = [paddle.to_tensor(3), paddle.to_tensor(4)]
    print(shape)
    obj.set_input_data("input_data", shape)
    obj.run()

Error details:

TypeError: float() argument must be a string or a real number, not 'list'
tests/onnxbase.py:218: TypeError
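The message comes from Python itself: float() accepts a string or a real number, but the failing line hands it the whole Python list rather than one element. A minimal, Paddle-free illustration, assuming plain ints as stand-ins for the 0-D tensors (the exact wording of the error varies between Python versions):

```python
# Plain ints stand in for the 0-D Paddle tensors used in the failing test.
shape = [3, 4]

try:
    # Effectively what the failing line does: convert the entire list at once.
    float(shape)
except TypeError as exc:
    print(type(exc).__name__, exc)

# Converting each element separately works.
values = [float(v) for v in shape]
print(values)  # [3.0, 4.0]
```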

Error location: the set_input_data() method in onnxbase.py, specifically the last line of the code below:

def set_input_data(self, group_name, *args):
        """
        params dict tool
        """
        self.kwargs_dict[group_name] = args
        if isinstance(self.kwargs_dict[group_name][0], tuple):
            self.kwargs_dict[group_name] = self.kwargs_dict[group_name][0]
        i = 0
        for in_data in self.kwargs_dict[group_name]:
            if isinstance(in_data, list):
                for tensor_data in in_data:
                    self.input_dtype.append(tensor_data.dtype)
                    self.input_spec.append(
                        paddle.static.InputSpec(
                            shape=tensor_data.shape,
                            dtype=tensor_data.dtype,
                            name=str(i)))
                    if len(tensor_data.shape) == 0:
                        self.input_feed[str(i)] = np.array(
                            # line 218: in_data is the whole list here, not tensor_data
                            float(in_data), dtype=dtype_map[in_data.dtype])
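In the list branch the loop variable is tensor_data, but the conversion uses in_data (the enclosing list), which is what triggers the TypeError. One way that branch could avoid the error is to convert each element instead. This is a minimal sketch, not the project's code: build_feed is a hypothetical standalone helper that only builds the feed dictionary and leaves out the InputSpec and dtype_map bookkeeping; since the excerpt above is truncated, where the real code advances i is also an assumption.

```python
import numpy as np


def build_feed(inputs):
    """Hypothetical helper: map each input, or each tensor inside a
    tensorlist, to a NumPy array keyed by a running string index
    (mirroring the str(i) names given to the InputSpecs)."""
    feed = {}
    i = 0
    for in_data in inputs:
        # A tensorlist contributes one feed entry per element.
        items = in_data if isinstance(in_data, list) else [in_data]
        for tensor_data in items:
            # Convert the element, not the enclosing list. Paddle tensors
            # expose .numpy(); other array-likes go through np.asarray.
            if hasattr(tensor_data, "numpy"):
                arr = tensor_data.numpy()
            else:
                arr = np.asarray(tensor_data)
            feed[str(i)] = arr
            i += 1
    return feed
```

Flattening the list this way gives the ONNX model one scalar input per list element, which matches how the surrounding loop already creates one InputSpec per tensor_data.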

Additional notes: The same test runs without problems in Paddle itself.

Zheng-Bicheng commented 3 weeks ago

This will probably require changing the unit-test code: Paddle inputs cannot be used directly as ONNX inputs, so an extra to_numpy step is needed.
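The to_numpy step suggested above could look roughly like the sketch below. It is only an illustration: to_numpy is a hypothetical name, and it assumes Paddle tensors can be recognized by their .numpy() method (which paddle.Tensor provides), while any other value goes through np.asarray. Unlike a flattening approach, it keeps the nesting, so a tensorlist comes back as a list of NumPy arrays.

```python
import numpy as np


def to_numpy(data):
    """Hypothetical helper: recursively convert Paddle tensors (anything
    with a .numpy() method) inside lists/tuples to NumPy arrays,
    preserving the list/tuple structure."""
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type from the converted elements.
        return type(data)(to_numpy(d) for d in data)
    if hasattr(data, "numpy"):
        return data.numpy()
    return np.asarray(data)
```

For the test in this issue, to_numpy(shape) would turn [paddle.to_tensor(3), paddle.to_tensor(4)] into a list of two NumPy arrays, which can then be fed to ONNX Runtime one entry per ONNX input.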