ARM-software / armnn

Arm NN ML Software. The code here is a read-only mirror of https://review.mlplatform.org/admin/repos/ml/armnn
https://developer.arm.com/products/processors/machine-learning/arm-nn
MIT License

Gather(ND) dim error #756

Closed: zxros10 closed this issue 2 months ago.

zxros10 commented 5 months ago

gather_test.tar.gz

Execute command:

    ./aarch64_build/tests/ExecuteNetwork -N -I 100 -c GpuAcc -m gather_test/gather_dim_test_float32.tflite --reuse-buffers --tflite-executor parser

Output:

    Warning: No input files provided, input tensors will be filled with 0s.
    Info: ArmNN v33.1.0
    arm_release_ver of this libmali is 'g6p0-01eac0', rk_so_ver is '5'.
    Info: Initialization time: 23.86 ms.
    Error: Failed to parse operator #4 within subgraph #0 error: Operation has invalid output dimensions: 3 Output must be an (4 + 1 - 1) -D tensor at function ParseGather [/home/arm-user/source/armnn/src/armnnTfLiteParser/TfLiteParser.cpp:4786]

tracyn-arm commented 5 months ago

The Gather operator in the Compute Library, which Arm NN uses, builds the output shape in a specific way:

The docs for arm_compute::misc::shape_calculator::compute_gather_shape() are:

    /** Calculate the gather output shape of a tensor
     *
     * @param[in] input_shape   Input tensor shape
     * @param[in] indices_shape Indices tensor shape. Only supports for 2d and 3d indices
     * @param[in] actual_axis   Axis to be used in the computation
     *
     * @note Let input_shape be (X,Y,Z) and indices shape (W,O,P) and axis 1
     *       the new shape is computed by replacing the axis in the input shape with
     *       the indice shape so the output shape will be (X,W,O,P,Z)
     *
     * @return the calculated shape
     */

In the failing case provided, we have:

     * @note Let input_shape be [1,40,20,4] and indices shape [1] and axis 3
     *       the new shape is computed by replacing the axis in the input shape with
     *       the indice shape so the output shape will be [1,40,20,1]

This produces the reported error: the output tensor set in the model has shape [1,40,20] (3 dimensions), whereas the rule above gives [1,40,20,1] (4 + 1 - 1 = 4 dimensions). Arm NN conforms to the requirements of the library it uses, so it fails on the output tensor shape that was set in the model.