bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index

About fusion of **kdequantize kernel** and **simple bf16/fp16 matmul** #1319

Status: Open · Ther-nullptr opened this issue 1 month ago

Ther-nullptr commented 1 month ago

Feature request

A fused CUDA kernel that combines the main-weight dequantization step with the matrix multiplication, in order to reduce on-chip/off-chip data movement.

Motivation

I used profiling tools to analyze the runtime breakdown of QLoRA:

config: activation:(4,512,14336), weight:(4096,14336), precision:NF4/BF16, platform: NVIDIA A800 80GB PCIe

[profiler screenshots: kernel time breakdown]

Notice that the quantize/dequantize step for the main weight takes roughly 30%–50% as long as the main matrix multiplication itself. Analyzing the computation process (a minimal Python sketch of this path follows the list):

  1. load the 4-bit weight matrix from DRAM (HBM) into SRAM (shared memory). <kernel 1>
  2. dequantize the weight to fp16/bf16 in SRAM. <kernel 1>
  3. write the dequantized weight back to DRAM. <kernel 1>
  4. load the fp16/bf16 weight and the activation into SRAM. <kernel 2>
  5. compute in SRAM. <kernel 2>
  6. write the output back to DRAM. <kernel 2>
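For reference, this two-kernel path can be reproduced from Python with the existing `bitsandbytes.functional` helpers (a minimal sketch, assuming the shapes from the config above and the `quantize_4bit`/`dequantize_4bit` API):

```python
import torch
import bitsandbytes.functional as F

# Shapes from the config above: activation (4, 512, 14336), weight (4096, 14336).
x = torch.randn(4, 512, 14336, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, 14336, dtype=torch.bfloat16, device="cuda")

# Pack the weight to NF4 (uint8 storage plus quantization state).
w_nf4, quant_state = F.quantize_4bit(w, quant_type="nf4")

# Kernel 1: dequantize NF4 -> bf16; the full 16-bit weight is written back to HBM.
w_bf16 = F.dequantize_4bit(w_nf4, quant_state)

# Kernel 2: plain bf16 matmul; the 16-bit weight is read from HBM again.
y = x @ w_bf16.t()
```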

So is it possible to fuse the kernels so that they behave like this:

  1. load the 4-bit weight matrix and the 16-bit activation from DRAM (HBM) into SRAM (shared memory). <kernel 1>
  2. dequantize the weight to fp16/bf16 in SRAM. <kernel 1>
  3. compute in SRAM. <kernel 1>
  4. write the output back to DRAM. <kernel 1>

Thus we only have to launch one kernel, saving one 16-bit weight store and one 16-bit weight load per matmul.
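As a back-of-the-envelope check on that saving, using the weight shape from the config above (an estimate of the redundant HBM traffic only, not a measured result):

```python
# Weight (4096, 14336) materialized in bf16 (2 bytes per element).
rows, cols, bytes_per_elem = 4096, 14336, 2
weight_bf16_bytes = rows * cols * bytes_per_elem   # 117,440,512 bytes = 112 MiB

# Unfused path: kernel 1 writes the bf16 weight to HBM, kernel 2 reads it back.
redundant_traffic = 2 * weight_bf16_bytes          # ~224 MiB of extra HBM traffic
print(f"{redundant_traffic / 2**20:.0f} MiB of HBM traffic avoidable by fusion")
```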

Your contribution

I have just observed this and want to ask whether this idea is feasible.

matthewdouglas commented 1 month ago

We do have a more optimal GEMV path for inference with a batch size of 1, but otherwise your thought process here is sound. It should be possible, and I would suggest following along with a potential FLUTE integration in #1293.
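For anyone following along, a minimal sketch of how both paths are reached through `bitsandbytes.matmul_4bit` (assuming a recent bitsandbytes release; the exact dispatch condition may differ between versions):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

w = torch.randn(4096, 14336, dtype=torch.bfloat16, device="cuda")
w_nf4, quant_state = F.quantize_4bit(w, quant_type="nf4")

# Single-token inference: matmul_4bit dispatches to the fused 4-bit GEMV kernel,
# which dequantizes on the fly instead of materializing a bf16 copy of the weight.
x_single = torch.randn(1, 1, 14336, dtype=torch.bfloat16, device="cuda")
y_single = bnb.matmul_4bit(x_single, w_nf4.t(), quant_state=quant_state)

# Larger batch (as in the profile above): falls back to dequantize_4bit followed by
# a regular bf16 matmul, i.e. the two-kernel path this issue proposes to fuse.
x_batch = torch.randn(4, 512, 14336, dtype=torch.bfloat16, device="cuda")
y_batch = bnb.matmul_4bit(x_batch, w_nf4.t(), quant_state=quant_state)
```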