ARM-software / ComputeLibrary

The Compute Library is a set of computer vision and machine learning functions optimised for both Arm CPUs and GPUs using SIMD technologies.

NEGEMMLowpMatrixMultiplyCore U8 performs worse than NEGEMM #1031

Closed yd2102 closed 1 year ago

yd2102 commented 1 year ago

Output of 'strings libarm_compute.so | grep arm_compute_version': arm_compute_version=v22.11 Build options: {'embed_kernels': '1', 'toolchain_prefix': 'aarch64-none-linux-gnu-', 'os': 'linux', 'opencl': '0', 'neon': '1', 'build_dir': 'armv8.2-a-sve-neon', 'asserts': '0', 'arch': 'armv8.2-a-sve', 'Werror': '1'} Git hash=unknown

Platform: ARMv8.4-a Neoverse-V1

Operating System: Ubuntu 22.04

Problem description: Hi,

I was benchmarking INT8 matrix multiplication using different thread settings. However, regardless of threading, the performance of "NEGEMMLowpMatrixMultiplyCore" is much worse than that of the float32 "NEGEMM" kernel.

Here's a simple test program that I used:

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "utils/Utils.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>

using namespace arm_compute;

static const size_t M = 10;
static const size_t N = 768;
static const size_t K = 768;
static const size_t iterations = 10000;
static const float  alpha = 1.0f;
static const float  beta = 0.0f;

void benchmark(IFunction *kernel, const std::string& name, const int threads)
{
    printf("[%s] Using %d threads...\n", name.c_str(), threads);

    // Use specified number of threads
    NEScheduler::get().set_num_threads(threads);

    // Warm up kernel
    for (auto i = 0; i < 100; i++)
    {
        kernel->run();
    }

    // Scale the number of runs with the thread count so per-thread work stays comparable
    auto total = threads * iterations;
    auto start = std::chrono::steady_clock::now();

    // Execute kernel
    for (auto i = 0; i < total; i++)
    {
        kernel->run();
    }

    auto stop = std::chrono::steady_clock::now();
    std::chrono::duration<double> diff = stop - start;
    double time = diff.count();
    double tp = 2 * M * N * K * total / time / 1e9; // throughput in Gop/s (printed as "op/s" below)

    printf("[%s] %f ms/iter, %f op/s\n", name.c_str(), 1e3 * time / total, tp);
}

// Find min and max value in a float array
void find_min_max(int size, const float *data, float *min, float *max)
{
    *min = *max = data[0];
    for(int i = 0; i < size; i++)
    {
        const float val = data[i];
        *min            = std::min(*min, val);
        *max            = std::max(*max, val);
    }
}

// Return reasonable quantisation parameters to use for an array of floats
// based on min and max values
QuantizationInfo choose_quantization_params(float min, float max)
{
    // Extend the [min,max] interval to contain 0 so we can represent it exactly
    min = std::min(min, 0.f);
    max = std::max(max, 0.f);

    // Set the quantized min and max in float values
    const float qmin = 0;
    const float qmax = 255;

    // Determine the scale
    const float scale = (max - min) / (qmax - qmin);

    // Determine the zero-point; using affine equation val = (qval-zerop) * scale
    const float zero_point_real = qmin - min / scale;

    // But we need to nudge the zero_point to an integer (exact quantized value)
    std::uint8_t zero_point_nudged = 0;
    if(zero_point_real < qmin)
    {
        zero_point_nudged = qmin;
    }
    else if(zero_point_real > qmax)
    {
        zero_point_nudged = qmax;
    }
    else
    {
        zero_point_nudged = static_cast<std::uint8_t>(support::cpp11::round(zero_point_real));
    }

    QuantizationInfo qinfo = QuantizationInfo(scale, zero_point_nudged);
    return qinfo;
}
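// Worked example (an added note; it matches the QuantisationInfo printed in the
// output below): with min ≈ -1 and max ≈ 1, scale = 2 / 255 ≈ 0.00784 and
// zero_point_real = 0 - (-1 / 0.00784) ≈ 127.5, which nudges to 127 or 128.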

int main(int argc, char **argv)
{
    Tensor      src0;
    Tensor      src1;
    Tensor      dst0;
    NEGEMM      fgemm;

    // Populate tensor information
    src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
    src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
    dst0.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));

    // Configure kernel
    fgemm.configure(&src0, &src1, nullptr, &dst0, alpha, beta);

    // Allocate all tensors
    src0.allocator()->allocate();
    src1.allocator()->allocate();
    dst0.allocator()->allocate();
    auto *src0_ptr = reinterpret_cast<float *>(src0.buffer());
    auto *src1_ptr = reinterpret_cast<float *>(src1.buffer());
    auto *dst0_ptr = reinterpret_cast<float *>(dst0.buffer());

    // Initialize random inputs
    utils::fill_random_tensor(src0, -1.f, 1.f);
    utils::fill_random_tensor(src1, -1.f, 1.f);

    // Run benchmarking
    for (auto i : {1, 2, 4, 8})
    {
        benchmark(&fgemm, "NEGEMM", i);
    }

    float src0_min;
    float src0_max;
    float src1_min;
    float src1_max;
    find_min_max(M * K, src0_ptr, &src0_min, &src0_max);
    find_min_max(K * N, src1_ptr, &src1_min, &src1_max);

    // Get quantization parameters
    const QuantizationInfo src0_qinfo = choose_quantization_params(src0_min, src0_max);
    const QuantizationInfo src1_qinfo = choose_quantization_params(src1_min, src1_max);
    std::cout << "Matrix 1: min=" << src0_min << ", max=" << src0_max << ", ";
    std::cout << "QuantisationInfo(" << src0_qinfo.scale()[0] << ", " << src0_qinfo.offset()[0] << ")\n";
    std::cout << "Matrix 2: min=" << src1_min << ", max=" << src1_max << ", ";
    std::cout << "QuantisationInfo(" << src1_qinfo.scale()[0] << ", " << src1_qinfo.offset()[0] << ")\n";

    // Populate tensor information
    Tensor q_src0;
    Tensor q_src1;
    Tensor q_acc;
    q_src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::QASYMM8, src0_qinfo));
    q_src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::QASYMM8, src1_qinfo));
    q_acc.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::S32));

    // Quantize inputs
    NEQuantizationLayer ql_src0;
    NEQuantizationLayer ql_src1;
    ql_src0.configure(&src0, &q_src0);
    ql_src1.configure(&src1, &q_src1);

    // Configure kernel
    NEGEMMLowpMatrixMultiplyCore qgemm;
    qgemm.configure(&q_src0, &q_src1, nullptr, &q_acc);

    // Allocate all tensors
    q_src0.allocator()->allocate();
    q_src1.allocator()->allocate();
    q_acc.allocator()->allocate();

    // Run quantization layers
    ql_src0.run();
    ql_src1.run();

    // Run benchmarking
    for (auto i : {1, 2, 4, 8})
    {
        benchmark(&qgemm, "NEGEMMLowpMatrixMultiplyCore", i);
    }

    return 0;
}

The test program was compiled using:

g++ test.cpp utils/Utils.cpp -I. -Iinclude -std=c++14 -larm_compute -larm_compute_core -o test

On my end, I observe a significant performance drop when using the INT8 GEMM. Here's an example output from the program:

[NEGEMM] Using 1 threads...
[NEGEMM] 0.150591 ms/iter, 78.334470 op/s
[NEGEMM] Using 2 threads...
[NEGEMM] 0.093084 ms/iter, 126.729645 op/s
[NEGEMM] Using 4 threads...
[NEGEMM] 0.053842 ms/iter, 219.096056 op/s
[NEGEMM] Using 8 threads...
[NEGEMM] 0.045946 ms/iter, 256.743948 op/s
Matrix 1: min=-0.999493, max=0.999776, QuantisationInfo(0.00784027, 127)
Matrix 2: min=-0.999996, max=0.999996, QuantisationInfo(0.00784311, 127)
[NEGEMMLowpMatrixMultiplyCore] Using 1 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.584267 ms/iter, 20.190213 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 2 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.425128 ms/iter, 27.748069 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 4 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.243374 ms/iter, 48.470617 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 8 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.261622 ms/iter, 45.089787 op/s

The profiler shows the following:

    86.70%  test     libarm_compute.so    [.] arm_compute::cpu::kernels::CpuGemmLowpMatrixMultiplyKernel::run_op
     1.65%  test     libc.so.6            [.] __GI___memcpy_simd
     1.53%  test     libarm_compute.so    [.] arm_gemm::sve_hybrid_fp32_mla_6x4VL
     1.10%  test     [unknown]            [k] 0xffffa7c5ea0dd06c
     1.00%  test     [unknown]            [k] 0xffffa7c5ea19841c
     0.74%  test     libarm_compute.so    [.] arm_compute::cpu::kernels::CpuGemmInterleave4x4Kernel::run_op
     0.53%  test     libarm_compute.so    [.] arm_compute::cpu::kernels::(anonymous namespace)::run_offset_contribution

I'm also curious why the "sve_hybrid_s8s32_mmla_6x4VL" kernel wasn't used in this case?

My impression of ACL is that its INT8 GEMM (used for quantized inference) should perform much better than the float32 GEMM kernels. Please let me know if there's anything I missed.

Thanks!

GGGGxxxxxxxxr commented 1 year ago

Same here. In my tests with several models (WDSR 128x128, AlexNet, and a one-dimensional signal model), the INT8 operations don't give a promising speedup.

morgolock commented 1 year ago

Hi @yd2102

It looks like your code is using the NCHW layout; you will get better performance if you use NHWC.

You need to explicitly specify NHWC when initializing the tensor info, as shown in https://github.com/ARM-software/ComputeLibrary/blob/main/examples/neon_permute.cpp#L40
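For reference, a minimal sketch of what that can look like for the tensors in the test program above (an illustration, assuming the set_data_layout() setter on TensorInfo; the linked example passes the layout to the TensorInfo constructor instead):

// Sketch: initialize tensor info with an explicit NHWC data layout
TensorInfo src0_info(TensorShape(K, M), 1, DataType::F32);
src0_info.set_data_layout(DataLayout::NHWC); // the layout stays NCHW unless set explicitly
src0.allocator()->init(src0_info);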

You could also use the ACL benchmark examples along with the instrumentation to analyze kernel performance. If you build the library with benchmark_examples=1, you can use the instruments to look into the graph example's performance as shown below:

root@acl_hikey_9:~/tmp/acl_mt# LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH ./benchmark_graph_mobilenet_v2 --instruments=SCHEDULER_TIMER_MS --example_args='--layout=NHWC,--target=NEON,--fast-math,--type=QASYMM8'
Version = arm_compute_version=v0.0-unreleased Build options: {'standalone': '0', 'test_filter': 'ActivationLayer.cpp', 'opencl': '1', 'neon': '1', 'validation_tests': '1', 'examples': '0', 'debug': '0', 'arch': 'armv8a', 'benchmark_examples': '1'} Git hash=065d46b0042cb974063e915715f3295ca265e078
CommandLine = ./benchmark_graph_mobilenet_v2 --instruments=SCHEDULER_TIMER_MS --example_args=--layout=NHWC,--target=NEON,--fast-math,--type=QASYMM8 
CL_DEVICE_VERSION = OpenCL 2.0 not_released.51d50be.1502459415db9c37cbcbea279386cb09
Iterations = 1
Running [0] 'Examples/benchmark_graph_mobilenet_v2'
Threads : 1
Target : Neon
Data type : QASYMM8
Data layout : NHWC
Tuner enabled? : false
Cache enabled? : false
Tuner mode : Normal
Tuner file : 
MLGO file : 
Fast math enabled? : true

  SchedulerTimer/Conv/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #1:    AVG=4.2440 ms
  SchedulerTimer/Conv/CpuIm2ColKernel #0:    AVG=0.9520 ms
  SchedulerTimer/Conv_1/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #62:    AVG=2.9500 ms
  SchedulerTimer/Logits/AvgPool/CpuPool2dAssemblyWrapperKernel #63:    AVG=0.0550 ms
  SchedulerTimer/Logits/Conv2d_1c_1x1/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #64:    AVG=0.6760 ms
  SchedulerTimer/Predictions/Reshape/CpuReshapeKernel #65:    AVG=0.0590 ms
  SchedulerTimer/Predictions/Softmax/CpuLogits1DMaxKernel/neon_qu8_logits_1d_max #66:    AVG=0.0110 ms
  SchedulerTimer/Predictions/Softmax/CpuLogits1DSoftmaxKernel/neon_qu8_softmax_logits_1d #67:    AVG=0.0860 ms
  SchedulerTimer/expanded_conv/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #2:    AVG=1.4960 ms
  SchedulerTimer/expanded_conv/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #3:    AVG=2.6010 ms
  SchedulerTimer/expanded_conv_1/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #4:    AVG=8.4050 ms
  SchedulerTimer/expanded_conv_1/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst #5:    AVG=1.0920 ms
  SchedulerTimer/expanded_conv_1/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #6:    AVG=1.6500 ms
  SchedulerTimer/expanded_conv_10/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #37:    AVG=0.9680 ms
  SchedulerTimer/expanded_conv_10/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #38:    AVG=0.2350 ms
  SchedulerTimer/expanded_conv_10/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #39:    AVG=1.0050 ms
  SchedulerTimer/expanded_conv_11/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #40:    AVG=1.8990 ms
  SchedulerTimer/expanded_conv_11/add/CpuAddKernel/neon_qu8_add #43:    AVG=0.0560 ms
  SchedulerTimer/expanded_conv_11/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #41:    AVG=0.3510 ms
  SchedulerTimer/expanded_conv_11/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #42:    AVG=1.4570 ms
  SchedulerTimer/expanded_conv_12/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #44:    AVG=1.8920 ms
  SchedulerTimer/expanded_conv_12/add/CpuAddKernel/neon_qu8_add #47:    AVG=0.0550 ms
  SchedulerTimer/expanded_conv_12/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #45:    AVG=0.3530 ms
  SchedulerTimer/expanded_conv_12/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #46:    AVG=1.4610 ms
  SchedulerTimer/expanded_conv_13/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #48:    AVG=1.9520 ms
  SchedulerTimer/expanded_conv_13/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst #49:    AVG=0.1260 ms
  SchedulerTimer/expanded_conv_13/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #50:    AVG=0.6480 ms
  SchedulerTimer/expanded_conv_14/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #51:    AVG=1.2200 ms
  SchedulerTimer/expanded_conv_14/add/CpuAddKernel/neon_qu8_add #54:    AVG=0.0250 ms
  SchedulerTimer/expanded_conv_14/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #52:    AVG=0.1890 ms
  SchedulerTimer/expanded_conv_14/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #53:    AVG=1.0470 ms
  SchedulerTimer/expanded_conv_15/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #55:    AVG=1.2330 ms
  SchedulerTimer/expanded_conv_15/add/CpuAddKernel/neon_qu8_add #58:    AVG=0.0250 ms
  SchedulerTimer/expanded_conv_15/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #56:    AVG=0.1890 ms
  SchedulerTimer/expanded_conv_15/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #57:    AVG=1.0430 ms
  SchedulerTimer/expanded_conv_16/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #59:    AVG=1.2170 ms
  SchedulerTimer/expanded_conv_16/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #60:    AVG=0.2090 ms
  SchedulerTimer/expanded_conv_16/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #61:    AVG=2.0660 ms
  SchedulerTimer/expanded_conv_2/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #7:    AVG=4.1830 ms
  SchedulerTimer/expanded_conv_2/add/CpuAddKernel/neon_qu8_add #10:    AVG=0.2540 ms
  SchedulerTimer/expanded_conv_2/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #8:    AVG=1.4350 ms
  SchedulerTimer/expanded_conv_2/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #9:    AVG=2.0210 ms
  SchedulerTimer/expanded_conv_3/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #11:    AVG=4.0670 ms
  SchedulerTimer/expanded_conv_3/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst #12:    AVG=0.4420 ms
  SchedulerTimer/expanded_conv_3/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #13:    AVG=0.6780 ms
  SchedulerTimer/expanded_conv_4/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #14:    AVG=1.3340 ms
  SchedulerTimer/expanded_conv_4/add/CpuAddKernel/neon_qu8_add #17:    AVG=0.0780 ms
  SchedulerTimer/expanded_conv_4/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #15:    AVG=0.4780 ms
  SchedulerTimer/expanded_conv_4/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #16:    AVG=0.7980 ms
  SchedulerTimer/expanded_conv_5/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #18:    AVG=1.3290 ms
  SchedulerTimer/expanded_conv_5/add/CpuAddKernel/neon_qu8_add #21:    AVG=0.0760 ms
  SchedulerTimer/expanded_conv_5/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #19:    AVG=0.4720 ms
  SchedulerTimer/expanded_conv_5/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #20:    AVG=0.7620 ms
  SchedulerTimer/expanded_conv_6/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #22:    AVG=1.3390 ms
  SchedulerTimer/expanded_conv_6/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst #23:    AVG=0.1360 ms
  SchedulerTimer/expanded_conv_6/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #24:    AVG=0.3890 ms
  SchedulerTimer/expanded_conv_7/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #25:    AVG=0.9680 ms
  SchedulerTimer/expanded_conv_7/add/CpuAddKernel/neon_qu8_add #28:    AVG=0.0390 ms
  SchedulerTimer/expanded_conv_7/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #26:    AVG=0.2390 ms
  SchedulerTimer/expanded_conv_7/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #27:    AVG=0.6820 ms
  SchedulerTimer/expanded_conv_8/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #29:    AVG=0.9730 ms
  SchedulerTimer/expanded_conv_8/add/CpuAddKernel/neon_qu8_add #32:    AVG=0.0380 ms
  SchedulerTimer/expanded_conv_8/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #30:    AVG=0.2350 ms
  SchedulerTimer/expanded_conv_8/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #31:    AVG=0.6750 ms
  SchedulerTimer/expanded_conv_9/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #33:    AVG=0.9570 ms
  SchedulerTimer/expanded_conv_9/add/CpuAddKernel/neon_qu8_add #36:    AVG=0.0380 ms
  SchedulerTimer/expanded_conv_9/depthwise/depthwise/CpuDepthwiseConv2dAssemblyWrapperKernel/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst #34:    AVG=0.2380 ms
  SchedulerTimer/expanded_conv_9/project/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #35:    AVG=0.6730 ms
Executed 1 test(s) (1 passed, 0 expected failures, 0 failed, 0 crashed, 0 disabled) in 0 second(s)

Hope this helps.

yd2102 commented 1 year ago

It looks like your code is using NCHW layout, you will get better performance if you use NHWC.

You need to explicitly specify NHWC when initializing the tensor info as shown in https://github.com/ARM-software/ComputeLibrary/blob/main/examples/neon_permute.cpp#L40

Hi @morgolock, I updated my test program to use the NHWC layout for all tensors. The updated code is here.

But I still don't observe any performance improvement after the change.

morgolock commented 1 year ago

Hi @yd2102

What device are you running the test on?

I'd suggest using the benchmark example as shown above:

./benchmark_graph_mobilenet_v2 --instruments=SCHEDULER_TIMER_MS --example_args='--layout=NHWC,--target=NEON,--type=QASYMM8'

And then compare against

./benchmark_graph_mobilenet_v2 --instruments=SCHEDULER_TIMER_MS --example_args='--layout=NHWC,--target=NEON,--type=F32'

Then you could compare:

SchedulerTimer/expanded_conv_10/Conv2D/CpuGemmAssemblyWrapperKernel/a64_gemm_u8_4x4 #37: AVG=0.9550 ms

Vs

SchedulerTimer/expanded_conv_10/expand/Conv2D+expanded_conv_10/expand/BatchNorm/CpuGemmAssemblyWrapperKernel/a64_hybrid_fp32_mla_6x16 #58: AVG=1.0590 ms

The kernel names shown by the instruments will change depending on the device you are using. If the device supports the dot product instruction, you will see a good performance improvement in quantized GEMM.
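As a side note, here is a minimal standalone sketch (not part of this thread's test program) for checking the dot product and SVE extensions at runtime on Linux/AArch64:

#include <asm/hwcap.h> // HWCAP_ASIMDDP, HWCAP_SVE (AArch64 Linux)
#include <sys/auxv.h>  // getauxval
#include <cstdio>

int main()
{
    // Query the kernel-reported hardware capability bits
    const unsigned long hwcaps = getauxval(AT_HWCAP);
    std::printf("dotprod: %s\n", (hwcaps & HWCAP_ASIMDDP) ? "yes" : "no");
    std::printf("sve:     %s\n", (hwcaps & HWCAP_SVE) ? "yes" : "no");
    return 0;
}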

Hope this helps.

morgolock commented 1 year ago

Hi @yd2102

Please use QASYMM8_SIGNED instead of QASYMM8. You will get better performance. The following was executed on a Neoverse N1:

[NEGEMM] Using 1 threads...
[NEGEMM] 0.298839 ms/iter, 39.474346 op/s
[NEGEMM] Using 2 threads...
[NEGEMM] 0.168447 ms/iter, 70.030914 op/s
[NEGEMM] Using 4 threads...
[NEGEMM] 0.089269 ms/iter, 132.145273 op/s
[NEGEMM] Using 8 threads...
[NEGEMM] 0.059230 ms/iter, 199.164304 op/s
Matrix 1: min=-0.99991, max=0.999979, QuantisationInfo(0.0078427, 127)
Matrix 2: min=-0.999998, max=0.999984, QuantisationInfo(0.00784307, 128)
[NEGEMMLowpMatrixMultiplyCore] Using 1 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.073605 ms/iter, 160.268132 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 2 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.073852 ms/iter, 159.731630 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 4 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.065942 ms/iter, 178.891330 op/s
[NEGEMMLowpMatrixMultiplyCore] Using 8 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.077265 ms/iter, 152.675485 op/s
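
For reference, a minimal sketch of the corresponding change to the test program above (an assumed minimal diff, not the exact fix that was later merged): switch the quantized range to the signed 8-bit one and the tensor data type to QASYMM8_SIGNED.

// In choose_quantization_params(): use the signed 8-bit range
const float qmin = -128;
const float qmax = 127;
// ...and nudge the zero point into [-128, 127] as std::int8_t instead of std::uint8_t

// In main(): the quantized tensors become QASYMM8_SIGNED
q_src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::QASYMM8_SIGNED, src0_qinfo));
q_src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::QASYMM8_SIGNED, src1_qinfo));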

Hope this helps.

yd2102 commented 1 year ago

Hi @morgolock, yes, changing to QASYMM8_SIGNED invokes the optimized SVE kernel on my end (Neoverse V1):

[NEGEMMLowpMatrixMultiplyCore] Using 1 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.025886 ms/iter, 455.702897 op/s

vs (before):

[NEGEMMLowpMatrixMultiplyCore] Using 1 threads...
[NEGEMMLowpMatrixMultiplyCore] 0.629730 ms/iter, 18.732586 op/s

I've identified what caused this. I can help submit a PR soon.

morgolock commented 1 year ago

Hi @yd2102

Thanks for the contribution!

Closing this now as it has been fixed with the patch: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9250