NETranspose and NEReshaping are both relatively slow on Armv8 according to my own tests. Same issue here.
There appear to be fixed-format GEMM kernels suitable for computing ABᵀ, but I'm not sure whether such kernels are usable in this case. Is there example code that shows how to compute ABᵀ using the fixed-format Neon kernels?
@yd2102
I think the only practical example at present is within this oneDNN PR: https://github.com/oneapi-src/oneDNN/pull/1590. The changes to matmul, and the supporting acl_matmul_utils.cpp, show the absorption of an NETranspose of the "B" matrix into a re-order into the memory format expected by the fixed-format kernels.
Output of 'strings libarm_compute.so | grep arm_compute_version':
arm_compute_version=v23.02.1
Build options: {'Werror': '1', 'debug': '0', 'neon': '1', 'opencl': '0', 'os': 'linux', 'openmp': '1', 'cppthreads': '0', 'arch': 'armv8.2-a', 'multi_isa': '1', 'build': 'native'}
Git hash=b'd8bf9b53752a4f573120cf51b31055de8b3c7d29'
Platform: AWS Graviton3 aarch64 (ARMv8.4-a)
Operating System: Ubuntu 22.04.1 (build 23~22.04.1-Ubuntu)
Problem description:
Hi,
I am experiencing low performance when computing ABᵀ, where A and B are matrices of shape [M, K] and [N, K] respectively. This pattern of computation is very common in modern transformer-based ML models (e.g. the QKᵀ score computation in attention layers), so it is important that we compute it efficiently.
What I've found so far in ACL's repo is that, in order to compute ABᵀ, we need to compute Bᵀ first and then compute the matrix product of A and Bᵀ.
The Linux profiler shows that more than 60% of the time is spent on the matrix transpose, which is unexpected because a transpose is a much more lightweight operation than the GEMM itself.
So my questions are:
1) Is there a more optimized NETranspose kernel in ACL than "transpose_32bit_elements" that I can configure (my processor supports Arm SVE)?
2) An even more optimized approach would be to handle ABᵀ inside GEMM's tiled kernel, without computing the transpose and the GEMM separately. Does ACL support this kind of fused computation?
Thanks!
The Linux profiler shows > 60% of the time is spent on the matrix transpose:
[profiler output image not preserved]
Here is a short single-threaded version of my code that reproduces the problem:
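The original snippet did not survive extraction, so what follows is a reconstruction of the pattern described above, not the author's exact code: a single-threaded NETranspose of B followed by NEGEMM on A and Bᵀ, with placeholder shapes and iteration count.

```cpp
// Reconstruction (a sketch, not the original code): transpose B, then GEMM on
// A and B^T. Shapes M, N, K and the loop count are placeholders.
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/NEON/functions/NETranspose.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    NEScheduler::get().set_num_threads(1); // single-threaded, as in the report

    const size_t M = 128, N = 128, K = 768; // placeholder shapes

    // ACL TensorShape lists the fastest-moving (x) dimension first.
    Tensor a, b, bt, d;
    a.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));  // A:     [M, K]
    b.allocator()->init(TensorInfo(TensorShape(K, N), 1, DataType::F32));  // B:     [N, K]
    bt.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32)); // B^T:   [K, N]
    d.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));  // A*B^T: [M, N]

    NETranspose transpose;
    NEGEMM      gemm;
    transpose.configure(&b, &bt);
    gemm.configure(&a, &bt, nullptr, &d, 1.f, 0.f);

    a.allocator()->allocate();
    b.allocator()->allocate();
    bt.allocator()->allocate();
    d.allocator()->allocate();

    // Profiling attributes most of the time in this loop to transpose.run().
    for(int i = 0; i < 1000; ++i)
    {
        transpose.run();
        gemm.run();
    }
    return 0;
}
```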