HandH1998 / QQQ

QQQ is an innovative and hardware-optimized W4A8 quantization solution.

[QST] Speedup of GEMM #3

Open Hongbosherlock opened 3 days ago

Hongbosherlock commented 3 days ago

[Image: speedup of W4A8 and W8A8 GEMM versus the number of input tokens M]

Why is W4A8 faster than W8A8? W4A8 needs some additional operations before performing the INT8 GEMM. Does W8A8 here refer to per-token + per-channel quantization?
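A common reading of "per-token + per-channel" W8A8 is as follows. Each activation row (token) gets its own scale, and each weight output column gets its own scale. The int32 accumulator can then be dequantized after the integer GEMM using the outer product of the two scale vectors. The NumPy sketch below is illustrative only. The function names are hypothetical, symmetric quantization is assumed, and it is not the kernel used in the picture.

```python
import numpy as np

def quantize_per_token(x):
    """Symmetric int8 quantization with one scale per activation row (token)."""
    scale = np.abs(x).max(axis=1, keepdims=True) / 127.0    # shape (M, 1)
    scale = np.where(scale == 0, 1.0, scale)                 # guard all-zero rows
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def quantize_per_channel(w):
    """Symmetric int8 quantization with one scale per weight output column."""
    scale = np.abs(w).max(axis=0, keepdims=True) / 127.0    # shape (1, N)
    scale = np.where(scale == 0, 1.0, scale)                 # guard all-zero columns
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def w8a8_gemm(x, w):
    """(M, K) x (K, N): int8 GEMM with int32 accumulation, then per-token/per-channel dequant."""
    qx, sx = quantize_per_token(x)
    qw, sw = quantize_per_channel(w)
    acc = qx.astype(np.int32) @ qw.astype(np.int32)         # integer GEMM
    return acc * sx * sw                                      # (M,1) * (1,N) broadcast = outer product of scales

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64))
w = rng.standard_normal((64, 32))
y = w8a8_gemm(x, w)
rel_err = np.linalg.norm(y - x @ w) / np.linalg.norm(x @ w)
print(f"relative error vs fp64 GEMM: {rel_err:.4f}")
```

Because both scales factor out of the dot product, the INT8 GEMM itself needs no per-element dequantization. Only the epilogue multiplies by the scales.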

HandH1998 commented 3 days ago

@Hongbosherlock When the batch size is small, LLM inference becomes memory-bound. In this scenario, W4A8 halves the weight memory traffic compared to W8A8. Here, W8A8 refers to the GEMM alone without dequantization, so only the GEMM (General Matrix Multiply) time is measured.
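One way to make "memory-bound" concrete is arithmetic intensity: integer ops per byte moved from HBM for an (M,K) x (K,N) GEMM, where M is the number of activation rows (tokens). The sketch assumes an idealized model in which every operand crosses HBM exactly once, activations are int8, and the output is fp16 (2 bytes). These are assumptions, not measurements from the picture.

```python
def arithmetic_intensity(m, k, n, weight_bits, act_bits=8, out_bytes=2):
    """Integer ops per byte of HBM traffic for an (M,K) x (K,N) GEMM (idealized)."""
    ops = 2 * m * k * n                        # one multiply + one add per MAC
    bytes_moved = (m * k * act_bits / 8        # activations
                   + k * n * weight_bits / 8   # weights
                   + m * n * out_bytes)        # output
    return ops / bytes_moved

for m in (1, 16, 256):
    ai8 = arithmetic_intensity(m, 4096, 4096, weight_bits=8)
    ai4 = arithmetic_intensity(m, 4096, 4096, weight_bits=4)
    print(f"M={m}: W8A8 {ai8:.1f} ops/B, W4A8 {ai4:.1f} ops/B")
```

At M=1 with K=N=4096, this gives about 2 ops/byte for W8A8 and about 4 for W4A8. The weights dominate the traffic, so halving their size roughly doubles the intensity. Both values are far below what typical data-center GPUs need to become compute-bound.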

Hongbosherlock commented 3 days ago

I thought the batch size was 1 and the GEMM in this picture is (M,K) x (K,N), where M is the number of input tokens. M varies while the batch size is fixed (at least, that is how the picture looks). What is the batch size in this picture?

HandH1998 commented 3 days ago

@Hongbosherlock You are right. I should have said the number of input tokens.

Hongbosherlock commented 3 days ago

Is the picture profiling a single GEMM rather than end-to-end LLM inference? I still don't understand why W4A8 is faster than W8A8 for a GEMM of (M,K) x (K,N).

HandH1998 commented 3 days ago

Yes, it profiles a single GEMM. The GEMM needs to fetch the weights from HBM, and W4A8 cuts that memory access time in half compared to W8A8.
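The "half the memory access time" argument can be sketched with an idealized roofline. Time is the larger of compute time and memory time (perfect overlap), each operand crosses HBM once, and the int4->int8 cost is ignored. The peak-ops and bandwidth figures below are hypothetical placeholders, not numbers for any specific GPU. The helper names are also hypothetical.

```python
def gemm_time_roofline(m, k, n, weight_bits, peak_ops, bandwidth, act_bits=8, out_bytes=2):
    """Idealized time (s) for an (M,K) x (K,N) integer GEMM: max(compute, memory)."""
    ops = 2 * m * k * n
    bytes_moved = m * k * act_bits / 8 + k * n * weight_bits / 8 + m * n * out_bytes
    return max(ops / peak_ops, bytes_moved / bandwidth)

# Placeholder hardware numbers (hypothetical, not any specific GPU).
PEAK_INT8_OPS = 600e12   # integer ops per second
HBM_BW = 2e12            # bytes per second

def w4a8_over_w8a8(m, k=4096, n=4096):
    """Predicted speedup of W4A8 over W8A8 under the roofline model."""
    t8 = gemm_time_roofline(m, k, n, 8, PEAK_INT8_OPS, HBM_BW)
    t4 = gemm_time_roofline(m, k, n, 4, PEAK_INT8_OPS, HBM_BW)
    return t8 / t4

for m in (1, 16, 256, 4096):
    print(m, round(w4a8_over_w8a8(m), 3))
```

In this model, the W4A8 advantage approaches but never exceeds the ratio of bytes moved, which is just under 2x at M=1. It falls to 1x once both kernels are compute-bound. A measured speedup above 2x would have to come from kernel-level effects that the model leaves out.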

Hongbosherlock commented 2 days ago

As I understand it, the process is as follows:

Every two int4 elements are packed into one int8 element.

  • W4A8: (1) load activation (M,K) and weight (K/2,N), (2) int4->int8, (3) int8 GEMM
  • W8A8: (1) load activation (M,K) and weight (K,N), (2) int8 GEMM
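The packing and the int4->int8 step listed above can be pictured as follows. Two signed int4 values share one byte, and unpacking sign-extends each nibble back to int8 before the int8 GEMM. The packing order (even rows along K in the low nibble) and the function names are assumptions for illustration. Production kernels often use interleaved layouts and bitwise tricks rather than this straightforward form.

```python
import numpy as np

def pack_int4(w):
    """Pack signed int4 values (int8 array in [-8, 7], shape (K, N)) into bytes of shape (K/2, N)."""
    assert w.shape[0] % 2 == 0, "K must be even"
    nib = (w.astype(np.int16) & 0xF).astype(np.uint8)   # 4-bit two's-complement codes
    return nib[0::2] | (nib[1::2] << 4)                   # even rows -> low nibble, odd rows -> high nibble

def sign_extend4(n):
    """Map 4-bit codes 0..15 back to signed values -8..7 as int8."""
    n = n.astype(np.int8)
    return np.where(n >= 8, n - 16, n).astype(np.int8)

def unpack_int4_to_int8(p):
    """Inverse of pack_int4: (K/2, N) bytes -> (K, N) int8, ready for an int8 GEMM."""
    out = np.empty((p.shape[0] * 2,) + p.shape[1:], dtype=np.int8)
    out[0::2] = sign_extend4(p & 0xF)
    out[1::2] = sign_extend4(p >> 4)
    return out

rng = np.random.default_rng(0)
w4 = rng.integers(-8, 8, size=(8, 4)).astype(np.int8)
packed = pack_int4(w4)
print(packed.shape, packed.dtype)
print(np.array_equal(unpack_int4_to_int8(packed), w4))
```

With a per-channel scale, this conversion is lossless. The weight scale is applied after the int8 GEMM, as in W8A8.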

But in the picture above, why does W4A8 perform so much better when M is small? When the number of input tokens (M) is 1, the speedup of W4A8 is 3.5 and the speedup of W8A8 is 1.5, so W4A8 is more than twice as fast as W8A8.

If there are any misunderstandings on my part, please point them out. Thanks!

HandH1998 commented 2 days ago

@Hongbosherlock You can use Nsight Compute to analyze it. I don't know why the speedup is more than 2x.

Hongbosherlock commented 2 days ago

You mean: "W4A8 can save the memory access time by half compared to W8A8."

W4A8:

  • t1: load activation (M,K) and weight (K/2,N)
  • t2: int4->int8
  • t3: int8 GEMM

Time of W4A8 GEMM: t1+t2+t3

W8A8:

  • T1: load activation (M,K) and weight (K,N)
  • T2: int8 GEMM

Time of W8A8 GEMM: T1+T2

For the int8 GEMM: t3 = T2

In this picture, (T1+T2)/(t1+t2+t3) = 3.5/1.5 > 2, from which we get T1 > 2t1+2t2+t3. I don't understand how that could be.

HandH1998 commented 2 days ago

Computation and memory access overlap, so in practice the execution time cannot be calculated as above.
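Hongbosherlock's algebra can be written out with the substitution step t3 = T2 made explicit. This assumes both speedups in the picture are measured against the same baseline, so their ratio equals T_W8A8 / T_W4A8. The last line is an idealized assumption of perfect overlap between loads and compute, not a measured model.

```latex
\begin{aligned}
\frac{T_1 + T_2}{t_1 + t_2 + t_3} &= \frac{3.5}{1.5} \approx 2.33 > 2 \\
T_1 + T_2 &> 2t_1 + 2t_2 + 2t_3 \\
T_1 &> 2t_1 + 2t_2 + t_3 \qquad (\text{using } t_3 = T_2) \\[4pt]
T_{\mathrm{W4A8}} &\approx \max(t_1,\; t_2 + t_3), \qquad T_{\mathrm{W8A8}} \approx \max(T_1,\; T_2) \quad (\text{perfect overlap, assumed})
\end{aligned}
```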

Hongbosherlock commented 2 days ago

How is the performance of W8A8 tested, and which kernel is used?

HandH1998 commented 2 days ago

@Hongbosherlock We used the cuBLAS W8A8 (INT8) GEMM: https://github.com/AniZpZ/AutoSmoothQuant/tree/main/csrc/int8gemm.

Hongbosherlock commented 1 day ago

How can I profile the W4A8 and W8A8 GEMM kernels the way you did for this picture?

HandH1998 commented 1 day ago

See https://github.com/HandH1998/QQQ/issues/2#issuecomment-2179921604.
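Independently of the linked comment, a GEMM micro-benchmark usually follows the same pattern: warm-up calls, then many timed repetitions, reporting the median. The sketch below times a NumPy integer matmul on the CPU purely as a stand-in. The function name `bench` and the sizes are hypothetical, and it does not reproduce the measurements in the picture.

```python
import statistics
import time

import numpy as np

def bench(fn, warmup=3, iters=20):
    """Median wall-clock seconds of fn() after warm-up calls."""
    for _ in range(warmup):
        fn()                                   # warm caches and allocators before timing
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        # For GPU kernels, synchronize the device here (e.g. torch.cuda.synchronize())
        # before reading the clock, since kernel launches are asynchronous.
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)

# CPU stand-in for an int8 GEMM sweep over the number of tokens M (hypothetical sizes).
K, N = 512, 512
rng = np.random.default_rng(0)
w = rng.integers(-128, 128, size=(K, N)).astype(np.int32)
for m in (1, 16, 64):
    x = rng.integers(-128, 128, size=(m, K)).astype(np.int32)
    print(m, bench(lambda: x @ w))
```

For a kernel-level breakdown on the GPU (memory throughput versus compute), Nsight Compute, as suggested earlier in the thread, gives more detail than wall-clock timing.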

brisker commented 1 day ago

Can a W8A8 quantized model also be exported with the QQQ repo? Is it enough to change wbits from 4 to 8?

HandH1998 commented 1 day ago

There are many differences between W4A8 and W8A8; you cannot simply set wbits to 8 in QQQ.

brisker commented 17 hours ago

@HandH1998 Regarding the w4a8-gs128 setting, why does QQQ use a w4->fp16->w8 pipeline instead of the w4->w8 pipeline used in QServe-w4a8-gs128, given that w4->fp16->w8 seems to be slower? (In QServe's grouped w4a8 there is still no w4->fp16 step, just as in its w4a8 without groups.)
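The two dequantization paths in the question can be sketched in NumPy to make it concrete:

  • Pipeline A follows the w4->fp16->w8 description in the question.
  • Pipeline B is an integer-only w4->w8 path with small integer group scales, roughly in the spirit of what the question attributes to QServe.

Both functions, their names, the symmetric [-8, 7] int4 range, and the scale layout are assumptions for illustration. They are not QQQ's or QServe's kernels. When each fp16 group scale equals an integer times the per-channel scale, both paths yield the same int8 weights. This sketch does not measure the cost difference the question is about, namely the extra fp16 conversions inside the kernel.

```python
import numpy as np

def w4_fp16_w8(q4, s_group, s_channel, group_size):
    """Pipeline A (as described in the question): int4 -> fp16 -> int8.

    q4: (K, N) signed int4 codes held in int8, range [-8, 7] (assumed symmetric).
    s_group: (K // group_size, N) fp16 per-group scales.
    s_channel: (N,) fp16 per-output-channel scales used by the int8 GEMM.
    """
    s = np.repeat(s_group, group_size, axis=0)      # expand group scales to (K, N)
    w_fp16 = q4.astype(np.float16) * s               # dequantize to fp16
    w8 = np.round(w_fp16 / s_channel)                # requantize per output channel
    return np.clip(w8, -128, 127).astype(np.int8)

def w4_w8_integer(q4, s_group_int, group_size):
    """Pipeline B (QServe-like, assumed): int4 -> int8 using integer group scales only."""
    s = np.repeat(s_group_int, group_size, axis=0).astype(np.int16)
    w8 = q4.astype(np.int16) * s                     # integer multiply, no fp16 step
    return np.clip(w8, -128, 127).astype(np.int8)

rng = np.random.default_rng(0)
K, N, G = 256, 64, 128                               # G = group size (gs128)
q4 = rng.integers(-8, 8, size=(K, N)).astype(np.int8)
s_int = rng.integers(1, 17, size=(K // G, N))        # integer group scales
s_ch = rng.uniform(0.005, 0.02, size=N).astype(np.float16)
s_grp = (s_int * s_ch).astype(np.float16)            # equivalent fp16 group scales
a = w4_fp16_w8(q4, s_grp, s_ch, G)
b = w4_w8_integer(q4, s_int, G)
print(np.array_equal(a, b))
```

Pipeline A accepts arbitrary fp16 group scales. Pipeline B only applies when each group scale is an integer multiple of the per-channel scale.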