oneapi-src / oneDNN

oneAPI Deep Neural Network Library (oneDNN)
https://uxlfoundation.org
Apache License 2.0

[ARM] Support 8bit/4bit weights decompression for Matmul primitive #2081

Open dmitry-gorokhov opened 1 month ago

dmitry-gorokhov commented 1 month ago

Problem statement

Latency-oriented LLM workloads are memory-bound: inference speed is limited by how fast the model weights can be read from DDR. That is why weight compression is a major optimization technique (4-bit weight compression can bring up to 4x lower latency compared with bf16/fp16 weights).
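
A rough back-of-the-envelope check of the "up to 4x" figure: in a memory-bound decode step every weight is read once per token, so latency scales roughly with the number of weight bytes. The sketch below is illustrative only. The helper names are hypothetical, and `effective_bits` adds the overhead of one scale (and optionally one zero point) per group, which is why the real gain is slightly below the ideal 4x.

```python
# Back-of-the-envelope weight traffic, assuming a memory-bound decode step
# where every weight is read once from DDR per generated token.
# Helper names are hypothetical; they are not part of oneDNN.

BITS_PER_WEIGHT = {"bf16": 16, "fp16": 16, "u8": 8, "i8": 8, "u4": 4, "i4": 4}


def weight_bytes(num_params: int, dtype: str) -> float:
    """Bytes of raw weight data read per token (scales/zero points ignored)."""
    return num_params * BITS_PER_WEIGHT[dtype] / 8


def speedup_vs(num_params: int, dtype: str, baseline: str = "bf16") -> float:
    """Ideal latency ratio if inference time is proportional to weight bytes."""
    return weight_bytes(num_params, baseline) / weight_bytes(num_params, dtype)


def effective_bits(dtype: str, group_size: int,
                   scale_bits: int = 16, zp_bits: int = 0) -> float:
    """Average bits per weight including one scale (+ optional zp) per group."""
    return BITS_PER_WEIGHT[dtype] + (scale_bits + zp_bits) / group_size


if __name__ == "__main__":
    n = 7_000_000_000  # hypothetical 7B-parameter model
    for dt in ("bf16", "i8", "i4"):
        print(f"{dt:>4}: {weight_bytes(n, dt) / 2**30:6.2f} GiB per token, "
              f"ideal speedup {speedup_vs(n, dt):.1f}x")
    # With fp16 scales per group of 128 weights, i4 costs 4.125 bits/weight.
    print(f"i4, group 128: {16 / effective_bits('i4', 128):.2f}x vs bf16")
```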

Preferred solution

oneDNN has already extended the x64 brgemm Matmul primitive to support 8-bit and 4-bit compressed weights with the following decompression math:

  1. Decompress a block of weights into a temporary buffer (via brgemm_matmul_copy_b): w_fp = (w_compressed - zp) * scale.
  2. Call the regular floating-point Matmul on the decompressed weight block.
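
The two steps above can be illustrated with a plain reference sketch. This is not oneDNN code and does not model brgemm_matmul_copy_b itself; it only shows the math. Assumptions: the weight W is stored K x N (row-major nested lists), there is one scale and one zero point per output channel n, and `decompress_block` / `matmul_decompressed` are hypothetical names.

```python
# Reference sketch of "decompress a weight block, then run an fp matmul".
# Illustrates w_fp = (w_compressed - zp) * scale only; not oneDNN code.
# Assumed layout: W is K x N, one scale and zero point per output channel n.

from typing import List

Matrix = List[List[float]]


def decompress_block(w_q: List[List[int]], zp: List[float],
                     scale: List[float], k0: int, k1: int) -> Matrix:
    """Step 1: dequantize rows k0..k1-1 of W into a temporary fp buffer."""
    return [[(w_q[k][n] - zp[n]) * scale[n] for n in range(len(scale))]
            for k in range(k0, k1)]


def matmul_decompressed(a: Matrix, w_q: List[List[int]], zp: List[float],
                        scale: List[float], block_k: int = 2) -> Matrix:
    """C = A @ dequant(W), processed one K-block at a time."""
    m, k_total, n_total = len(a), len(w_q), len(scale)
    c = [[0.0] * n_total for _ in range(m)]
    for k0 in range(0, k_total, block_k):
        k1 = min(k0 + block_k, k_total)
        w_blk = decompress_block(w_q, zp, scale, k0, k1)  # temp buffer
        # Step 2: ordinary fp matmul on the decompressed block.
        for i in range(m):
            for kk in range(k1 - k0):
                a_ik = a[i][k0 + kk]
                for n in range(n_total):
                    c[i][n] += a_ik * w_blk[kk][n]
    return c
```

For example, with A = [[1, 2, 3]], u8 weights [[10, 0], [12, 2], [8, 4]], zero points [8, 0] and scales [0.5, 1.0], the result is [[5.0, 16.0]]. Decompressing block by block keeps the fp copy small (block_k x N) instead of materializing the whole fp weight matrix.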

Since floating-point brgemm Matmul is already implemented for aarch64 (at least with SVE), the proposal is to extend it to support compressed weights in the same way as on x64.

The request is to support the following options:

  1. i4/u4/i8/u8 weights input + fp32/fp16/bf16 activations.
  2. An additional input for scales (per-output-channel values for int8, grouped for int4). Data type: FP32/FP16.
  3. An optional zero-point input (per-output-channel values for int8, grouped for int4). Its data type can match the weight element type, but we can also convert it to FP32/FP16 if the implementation prefers.
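
A minimal sketch of what the int4 options imply: unpacking two 4-bit values per byte and dequantizing with grouped scales and zero points. The packing order (low nibble first), two's-complement i4, and grouping along K (groups of `group` consecutive rows sharing one scale and zero point per output channel) are assumptions for illustration, not oneDNN's documented layout; the function names are hypothetical.

```python
# Sketch: unpack 4-bit weights and dequantize with grouped scales/zero points.
# Assumptions (not taken from oneDNN): two values per byte, low nibble first;
# i4 is two's complement; `group` consecutive rows along K share one scale
# and one zero point per output channel n.

from typing import List


def unpack_4bit(packed: bytes, count: int, signed: bool) -> List[int]:
    """Unpack `count` 4-bit values (low nibble first) from `packed`."""
    out = []
    for i in range(count):
        nib = (packed[i // 2] >> (4 * (i % 2))) & 0xF
        if signed and nib >= 8:  # sign-extend i4: 0x8..0xF -> -8..-1
            nib -= 16
        out.append(nib)
    return out


def dequant_grouped(w: List[List[int]], scale: List[List[float]],
                    zp: List[List[float]], group: int) -> List[List[float]]:
    """w is K x N; scale and zp are (K // group) x N; returns fp K x N."""
    return [[(w[k][n] - zp[k // group][n]) * scale[k // group][n]
             for n in range(len(w[0]))]
            for k in range(len(w))]
```

Under these assumptions, the per-output-channel int8 case is simply `group == K`: a single row of scales (and zero points) of length N.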

mgouicem commented 3 days ago

@theComputeKid

theComputeKid commented 3 days ago

Thanks. I was expecting this to eventually be requested.