ARM-software / ComputeLibrary

The Compute Library is a set of computer vision and machine learning functions optimised for both Arm CPUs and GPUs using SIMD technologies.

linear+gelu fused operator is not supported in ACL #1083

Open snadampal opened 6 months ago

snadampal commented 6 months ago

Output of 'strings libarm_compute.so | grep arm_compute_version':
arm_compute_version=v23.11
Build options: {'Werror': '0', 'debug': '0', 'neon': '1', 'opencl': '0', 'embed_kernels': '0', 'os': 'linux', 'arch': 'armv8a', 'build': 'native', 'multi_isa': '1', 'fixed_format_kernels': '1', 'openmp': '1', 'cppthreads': '0'}
Git hash=b'add70ace1e57f65d1ae4d0cedaec6e4578cf87ff'

Platform: AWS c7g.16xl

Operating System: Ubuntu 22.04

Problem description: PyTorch 2.0 introduced torch.compile() for neural network compilation. One of the important techniques graph compilation employs is operator fusion, so to execute the compiled graphs efficiently the platform needs to support the fused operators. For example, for the BERT base model (and, I think, any transformer model), inner_product+relu and matmul+relu (or gelu or tanh) are commonly fused in the linear layers. The issue is that ACL 23.11 doesn't support the above-mentioned fused operators, so we are not able to take full advantage of the PyTorch graph compilation optimizations on aarch64.
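For concreteness, here is a minimal sketch of the pattern (the 768/3072 sizes are just bert-base's feed-forward dimensions, used for illustration):

import torch
import torch.nn as nn

# bert-base feed-forward sub-block: this linear + gelu pair is what the
# compiled graph wants to run as a single fused oneDNN primitive
# (inner_product with an eltwise_gelu_erf post-op).
ffn = nn.Sequential(
    nn.Linear(768, 3072),
    nn.GELU(),                        # erf-based gelu by default
)

with torch.no_grad():
    out = ffn(torch.randn(28, 768))   # 28 tokens x hidden size 768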

Steps to reproduce: when you run the script below, you can see that the fused operators fall back to the oneDNN C++ reference kernels because ACL doesn't support them.

pip3 install torch==2.1.1
export DNNL_VERBOSE=1

import torch
from transformers import BertTokenizer, BertModel
import torch._inductor.config as config

# enable Inductor weight prepacking and freezing
config.cpp.weight_prepack = True
config.freezing = True

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").eval()

text = "Test bert base torch.compile on aarch64 with ACL"
encoded_input = tokenizer(text, return_tensors='pt')

model = torch.compile(model)

with torch.set_grad_enabled(False):
    model(**encoded_input)

Note: On PyTorch main, I have disabled the operator fusion for aarch64 so that at least the other optimizations from the compilation can be used; here is the PR. So, please use PyTorch 2.1.1 to reproduce the issue.

morgolock commented 5 months ago

Hi @snadampal

Thanks for raising this. We will discuss the feature request with the team.

jondea commented 5 months ago

oneDNN shouldn't be falling back to the reference kernels (i.e. ref). acl_post_ops_t should try to fuse the operators in ACL and, if it isn't able to, it should fall back to calling the ACL activation functions as a separate layer. That is slower than fusion because you have an extra store and load for each element, but it's an order of magnitude faster than the reference kernels. Let me know if there are any cases like this; the fix may be relatively simple.

Also, do we know the relative importance of the different activations and data types? I haven't done any in-depth analysis, but for compute-bound activations like gelu or tanh there may not be much benefit to fusing them over having a separate activation layer. For the simpler memory-bound activations there should be a larger benefit. I think non-leaky relu (i.e. leaky relu with α = 0) is already fused into quite a few kernels, although as far as I know leaky relu is not yet.
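As a very rough illustration (not a proper analysis), timing the standalone activations against the GEMM they would be fused into gives a feel for this; relu is essentially a memory-bound copy, whereas gelu and tanh spend most of their time in the math itself:

import time
import torch

x = torch.randn(4096, 4096)
w = torch.randn(4096, 4096)

def bench(fn, iters=20):
    fn()                                    # warm-up
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters * 1e3   # ms per call

print(f"matmul         : {bench(lambda: x @ w):8.3f} ms")
print(f"relu (memory)  : {bench(lambda: torch.relu(x)):8.3f} ms")
print(f"gelu (compute) : {bench(lambda: torch.nn.functional.gelu(x)):8.3f} ms")
print(f"tanh (compute) : {bench(lambda: torch.tanh(x)):8.3f} ms")

If an activation's standalone cost is dominated by the extra store and load, fusing it into the GEMM removes most of that cost; if it is dominated by the arithmetic, fusion saves comparatively little.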

snadampal commented 3 months ago

Hi @jondea, the torch-compiled version of bert-base (the script I provided above) has an attr-post-ops:eltwise_gelu_erf post-op which is not supported in ACL, hence it falls back to the C++ reference kernel for fp32, and it fails to create the primitive for bf16 fast math mode (because there are no reference fastmath kernels). The post-op init returns unimplemented from this: https://github.com/oneapi-src/oneDNN/blob/main/src/cpu/aarch64/acl_utils.cpp#L108

I'm not sure whether the gap in ACL is only the fused kernel or also the individual kernels.

for fp32: onednn_verbose,primitive,exec,cpu,inner_product,ref:any,forward_training,src_f32::blocked:ab::f0 wei_f32::blocked:Ab8a::f0 bia_f32::blocked:a::f0 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-post-ops:eltwise_gelu_erf ,,mb28ic768oc3072,17.3999

for bf16 fast math mode: RuntimeError: could not create a primitive descriptor for an inner product forward propagation primitive

jondea commented 3 months ago

Great, thanks for the reproducer. It looks like ACL does in fact have a GELU implementation, at least for NEON FP32 https://github.com/ARM-software/ComputeLibrary/blob/bc89a0b690200750040770bda0981f4a37b389c4/src/cpu/kernels/CpuActivationKernel.cpp#L107

It should be straightforward to hook this up here:

https://github.com/oneapi-src/oneDNN/blob/5bb1a8ee84cea8013be409f3474ad7e0c6b1e654/src/cpu/aarch64/acl_utils.cpp#L50

and it will automatically get picked up by the acl_post_ops_t inside acl_inner_product_t.

I have made an internal issue to take a look at this and get back to you. Things are quite busy at the moment so I'll need to get back to you on timescales.

jondea commented 3 months ago

@snadampal we now have a PR up for ACL GELU erf in oneDNN: https://github.com/oneapi-src/oneDNN/pull/1843. This should enable ACL primitives (including inner product) to be used when there's a GELU erf post op. This isn't a fusion in the sense that the activation happens inside the GEMM kernel, but it does mean that you can make use of the ACL accelerated kernels when there are post ops in oneDNN.
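To verify once that change is picked up, something like the snippet below (the 28x768 -> 3072 shapes mirror the mb28ic768oc3072 line from your verbose output), run under DNNL_VERBOSE=1, should show whether the inner_product still reports ref:any or now picks an ACL-backed implementation:

import torch
import torch._inductor.config as config

config.cpp.weight_prepack = True
config.freezing = True

m = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),              # approximate='none', i.e. the erf variant
).eval()

compiled = torch.compile(m)
with torch.no_grad():
    compiled(torch.randn(28, 768))
# run with DNNL_VERBOSE=1 and check whether the inner_product line still shows ref:any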

snadampal commented 3 months ago

Thanks for the note, @jondea, I will take a look at it.

snadampal commented 2 months ago

Hi @jondea, what about fusion support for the other primitive and post-op combinations? Could you please add support for matmul + post-ops like gelu/relu/erf/tanh as well?

jondea commented 2 months ago

At the oneDNN level, we should automatically support combining matmul/conv/inner product with any binary or eltwise post op that is supported by the equivalent standalone ACL primitive. So matmul/conv/inner product + gelu/relu/erf/tanh should be accelerated by ACL in oneDNN (GELU went into v3.5).