oneapi-src / oneDNN

oneAPI Deep Neural Network Library (oneDNN)
https://uxlfoundation.org
Apache License 2.0

GEMM API for efficient LLM inference with W8A16 #1788

Status: Open. oleotiger opened this issue 6 months ago.

oleotiger commented 6 months ago

I want to run inference on a quantized LLaMA model (W8A16: 8-bit weights, 16-bit activations) on Armv9 (with SVE) using oneDNN. The LLaMA weights are quantized per group.
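To make the setup concrete, here is a minimal Python sketch of per-group weight quantization along K. It assumes an asymmetric int8 scheme with one scale and one zero point ("shift") per group of `group_size` consecutive weights in a row. The function names, the N x K row layout, and the choice to widen each group's range to include 0 are illustrative assumptions, not the format of any particular LLaMA checkpoint or of oneDNN.

```python
def quantize_per_group(row, group_size):
    """Quantize one weight row (length K) into int8 codes plus per-group scale/zero point.

    Hypothetical helper: asymmetric int8, one (scale, zero_point) pair per group.
    """
    assert len(row) % group_size == 0, "K must be a multiple of the group size"
    codes, scales, zero_points = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        # Widen the range to include 0 so that 0.0 is exactly representable.
        lo, hi = min(min(group), 0.0), max(max(group), 0.0)
        scale = (hi - lo) / 255.0 if hi > lo else 1.0
        zp = max(-128, min(127, round(-128 - lo / scale)))
        scales.append(scale)
        zero_points.append(zp)
        for w in group:
            codes.append(max(-128, min(127, round(w / scale) + zp)))
    return codes, scales, zero_points


def dequantize_per_group(codes, scales, zero_points, group_size):
    """Recover approximate weights: w ~= (q - zp[g]) * scale[g], with g = k // group_size."""
    return [(q - zero_points[k // group_size]) * scales[k // group_size]
            for k, q in enumerate(codes)]
```

The group index `g = k // group_size` is the mapping that any packed weight layout has to preserve for the scales and shifts to stay usable.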

As I understand it, I need to prepack the weights to avoid the cost of repacking them on every call. However, packing rearranges the weights and breaks their alignment with the per-group quantization scales and shifts (zero points). I also understand that dequantization needs to be fused into the kernel. If it is fused into packing instead, I end up storing a second copy of the weights in FP16, which effectively undoes the quantization.
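A scalar reference sketch of what fusing dequantization into the kernel means: the GEMM reads int8 codes, subtracts the group's zero point on the fly, and applies the group's scale once per group, so no FP16 copy of the weights is ever materialized. Python floats stand in for FP16 activations; the function name and the N x K weight layout are hypothetical.

```python
def gemm_w8a16_fused(A, Wq, scales, zero_points, group_size):
    """C[m][n] = sum_k A[m][k] * (Wq[n][k] - zp[n][g]) * scale[n][g], g = k // group_size.

    A:                   M x K activations (Python floats stand in for FP16)
    Wq:                  N x K int8 codes, one row per output channel
    scales, zero_points: N x (K // group_size) per-group parameters
    """
    M, K, N = len(A), len(A[0]), len(Wq)
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for g in range(K // group_size):
                zp = zero_points[n][g]
                partial = 0.0
                for k in range(g * group_size, (g + 1) * group_size):
                    # Subtract the zero point on the fly; the codes are never expanded.
                    partial += A[m][k] * (Wq[n][k] - zp)
                # Apply the group's scale once per group rather than once per element.
                acc += partial * scales[n][g]
            C[m][n] = acc
    return C
```

Applying the scale once per group only works if each group's codes stay contiguous along K for every output channel, which is why a packing scheme has to keep group boundaries intact.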

I haven't figured out how to combine prepacking and per-group dequantization.

Which interface should I use for prepacking? SVE vectors can be 256 or 512 bits wide; how does oneDNN adapt the packing to the vector length? After prepacking and saving the weights, how do I fuse dequantization into the kernel at compute time?
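One way to reconcile prepacking with per-group scales is to pack the scales and zero points with the same blocking as the weights. The sketch below is a hypothetical layout, not oneDNN's internal format (which this thread does not describe). N is split into panels of width `panel_width` (for example, one vector of FP16 lanes: 16 lanes at 256 bits, 32 at 512 bits), codes are stored k-major inside each panel, and the per-group parameters are stored per panel so that each group's scales form one contiguous lane vector. All names are illustrative.

```python
def pack_w8_panels(Wq, scales, zero_points, panel_width):
    """Hypothetical packing of N x K int8 codes into column panels of width NB.

    Codes:  packed[p][k][j] = Wq[p*NB + j][k]
    Params: pscale[p][g][j] = scales[p*NB + j][g]  (same for zero points)
    The last panel is zero-padded with scale 0.0, so padded lanes contribute nothing.
    """
    N, K, G = len(Wq), len(Wq[0]), len(scales[0])
    nb = panel_width
    packed, pscale, pzp = [], [], []
    for start in range(0, N, nb):
        cols = range(start, start + nb)
        packed.append([[Wq[n][k] if n < N else 0 for n in cols] for k in range(K)])
        pscale.append([[scales[n][g] if n < N else 0.0 for n in cols] for g in range(G)])
        pzp.append([[zero_points[n][g] if n < N else 0 for n in cols] for g in range(G)])
    return packed, pscale, pzp


def gemm_packed(A, packed, pscale, pzp, group_size, N):
    """Fused-dequantization GEMM over the packed panels; lists model vector lanes."""
    M, K = len(A), len(A[0])
    nb = len(packed[0][0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for p, panel in enumerate(packed):
            acc = [0.0] * nb  # one accumulator per lane
            for g in range(K // group_size):
                partial = [0.0] * nb
                zps = pzp[p][g]  # contiguous zero points for this panel and group
                for k in range(g * group_size, (g + 1) * group_size):
                    a = A[m][k]      # broadcast one activation across all lanes
                    row = panel[k]   # NB contiguous int8 codes
                    for j in range(nb):
                        partial[j] += a * (row[j] - zps[j])
                for j in range(nb):
                    acc[j] += partial[j] * pscale[p][g][j]
            for j in range(nb):
                if p * nb + j < N:  # drop padded lanes
                    C[m][p * nb + j] = acc[j]
    return C
```

Because the panel width is only a pack-time parameter in this sketch, the same kernel logic covers both vector lengths; only the packed buffers differ.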

vpirogov commented 6 months ago

@oleotiger, we are working on enabling per-group quantization in oneDNN. You can find a description of the proposed design for fused weight decompression here. The implementation is not yet available on any platform, though. For now, the only option is to decompress the weights separately, as you indicated.
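The separate-decompression fallback amounts to expanding the weights into a dense float matrix and then running an ordinary GEMM. The sketch below does that and also estimates storage, assuming 2-byte FP16 scales and 1-byte zero points (hypothetical sizes), to show why keeping a decompressed copy roughly doubles weight memory. Function names are illustrative.

```python
def decompress_weights(Wq, scales, zero_points, group_size):
    """Expand N x K int8 codes into a dense float matrix (the unfused fallback)."""
    return [[(q - zero_points[n][k // group_size]) * scales[n][k // group_size]
             for k, q in enumerate(row)]
            for n, row in enumerate(Wq)]


def gemm_reference(A, W):
    """C = A @ W^T, with W stored as N x K (one row per output channel)."""
    return [[sum(a * w for a, w in zip(a_row, w_row)) for w_row in W] for a_row in A]


def weight_bytes(N, K, group_size, scale_bytes=2, zp_bytes=1):
    """Bytes for int8 codes plus per-group params, versus a dense FP16 copy."""
    groups = N * (K // group_size)
    quantized = N * K + groups * (scale_bytes + zp_bytes)
    fp16_copy = N * K * 2
    return quantized, fp16_copy
```

For example, `weight_bytes(4096, 4096, 128)` returns `(17170432, 33554432)`: about 16.4 MiB for the quantized form versus 32 MiB for a dense FP16 copy of the same layer.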

vpirogov commented 6 months ago

+@igorsafo

vpirogov commented 5 months ago

The API and validation changes necessary to support W8A16 quantization have landed in the main and rls-v3.4 branches. The specifics are covered in the GPT Quantization RFC.

+@jondea, @milpuz01 for additional comments on Arm specifics.