Hongbosherlock opened 3 days ago
@Hongbosherlock When the batch size is small, LLM inference becomes memory-bound. In this scenario, using W4A8 can halve memory bandwidth usage compared to W8A8. W8A8 refers to the scenario without dequantization, where all the time goes to the GEMM (General Matrix Multiply) itself.
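As a rough back-of-the-envelope check (the layer shape here is hypothetical, not taken from the pic), at M = 1 the weight matrix dominates memory traffic, and 4-bit weights halve it:

```python
# Back-of-the-envelope weight-traffic model for a GEMM (M,K) x (K,N).
# The layer shape below is hypothetical, for illustration only.

def weight_bytes(K, N, bits):
    """Bytes of weight traffic for a K x N weight matrix at `bits` precision."""
    return K * N * bits // 8

K, N = 4096, 4096   # hypothetical layer shape
M = 1               # one input token: activation traffic (M*K bytes) is negligible

w8a8 = weight_bytes(K, N, 8)
w4a8 = weight_bytes(K, N, 4)

print(w8a8 // w4a8)  # 2: W4A8 moves half the weight bytes of W8A8
```

At M = 1 the activation traffic (M*K bytes) is tiny next to the K*N weight bytes, which is why the weight precision sets the bandwidth bill.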
I thought the batch size is 1, and GEMM is (M,K) x (K,N) in this pic. M refers to the number of input tokens; M is changing and the batch size is fixed (the pic seems like this). What's the batch size in this pic?
@Hongbosherlock You are right. I should refer to the number of input tokens.
Is the pic about profiling a single GEMM instead of LLM inference? I still don't understand why W4A8 is faster than W8A8 on GEMM (M,K) x (K,N).
Yes, it is profiling a single GEMM. GEMM needs to fetch weights from HBM, and W4A8 can save the memory access time by half compared to W8A8.
> GEMM needs to fetch weights from HBM, and W4A8 can save the memory access time by half compared to W8A8.
In my opinion, the process is as follows:
Every two int4 elements are packed into one int8 element.
- W4A8: (1) load activation (M,K) and weight (K/2,N), (2) int4 -> int8, (3) int8 GEMM
- W8A8: (1) load activation (M,K) and weight (K,N), (2) int8 GEMM
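A minimal sketch of the int4 packing/unpacking in step (2) above (pure Python; the low-nibble-first layout is an assumption here, real kernels use whatever layout the GEMM expects):

```python
# Two signed int4 values are packed into one int8 byte; before the INT8 GEMM
# they are unpacked (sign-extended) back to int8.
# Layout assumption: the first element lives in the low nibble.

def pack_int4_pair(lo, hi):
    """Pack two int4 values in [-8, 7] into one byte."""
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    return ((hi & 0xF) << 4) | (lo & 0xF)

def unpack_int4_pair(byte):
    """Recover the two signed int4 values from one packed byte."""
    def sign_extend(nibble):
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte & 0xF), sign_extend((byte >> 4) & 0xF)

packed = pack_int4_pair(-3, 7)
print(unpack_int4_pair(packed))  # (-3, 7)
```

This is why the W4A8 weight tensor is (K/2,N) in int8 storage: each stored byte carries two logical int4 weights.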
But in the above picture, why does W4A8 perform so much better when M is small? When the number of input tokens (M) is 1, the speedup of W4A8 is 3.5, and the speedup of W8A8 is 1.5. The former is more than twice the latter. If there are any misunderstandings on my part, please point them out. Thanks!
@Hongbosherlock You can use Nsight Compute to make an analysis. I have no idea why it is more than a 2x speedup.
you mean:

> W4A8 can save the memory access time by half compared to W8A8.

W4A8:
- t1: load activation (M,K) and weight (K/2,N)
- t2: int4 -> int8
- t3: int8 GEMM

time of W4A8 GEMM: t1 + t2 + t3

W8A8:
- T1: load activation (M,K) and weight (K,N)
- T2: int8 GEMM

time of W8A8 GEMM: T1 + T2

For int8 GEMM: t3 = T2. In this pic (T1 + T2) / (t1 + t2 + t3) = 3.5 / 1.5 > 2, so we can get T1 > 2*t1 + 2*t2 + t3. I don't understand how it could be?
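For what it's worth, the algebra does hold under this serial model; a quick exact-arithmetic check (the times are arbitrary positive placeholders):

```python
# Exact-arithmetic check of the serial-model derivation. Pick arbitrary
# positive t1, t2, t3, set T2 = t3, and solve
# (T1 + T2) / (t1 + t2 + t3) = 3.5 / 1.5 for T1; the inequality
# T1 > 2*t1 + 2*t2 + t3 then follows, since
# T1 = (7/3)*(t1 + t2) + (4/3)*t3 exceeds 2*(t1 + t2) + t3 termwise.
from fractions import Fraction

ratio = Fraction(7, 3)  # 3.5 / 1.5
t1, t2, t3 = Fraction(2), Fraction(1), Fraction(3)  # arbitrary positive times
T2 = t3
T1 = ratio * (t1 + t2 + t3) - T2

print(T1 > 2 * t1 + 2 * t2 + t3)  # True
```

So under a strictly serial model the measured ratio really would force a weight-load time more than twice the whole W4A8 kernel, which is the puzzle.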
Calculation and memory access overlap, so the execution time cannot be computed like the above in practice.
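A toy timing model of that overlap (all numbers below are hypothetical) shows why the additive decomposition breaks: with compute hidden behind memory traffic and different fixed overheads per kernel, the measured W8A8/W4A8 ratio can exceed 2 even though the weights only shrink by 2x:

```python
# Toy model: with compute/memory overlap, kernel time is closer to
# max(memory_time, compute_time) plus some fixed launch/epilogue overhead,
# not a sum of sequential stages. All microsecond values are hypothetical,
# chosen only to show that measured ratios need not follow the serial model.

def kernel_time(mem, compute, overhead):
    return max(mem, compute) + overhead

# hypothetical times at M = 1 (memory-bound: mem >> compute)
fp16 = kernel_time(mem=4.0, compute=0.3, overhead=1.0)  # 16-bit weights
w8a8 = kernel_time(mem=2.0, compute=0.3, overhead=1.3)  # 8-bit weights
w4a8 = kernel_time(mem=1.0, compute=0.4, overhead=0.4)  # 4-bit weights

print(fp16 / w8a8, fp16 / w4a8, w8a8 / w4a8)
```

With these made-up overheads the model reproduces speedups of roughly 1.5x (W8A8) and 3.6x (W4A8) over FP16, i.e. a W8A8/W4A8 gap above 2x, purely from per-kernel overhead differences. Only a profiler (e.g. Nsight Compute, as suggested above) can say what actually happens on the real kernels.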
How is the performance of W8A8 tested, and which kernel is used?
@Hongbosherlock We used the cuBLAS-based W8A8 GEMM from https://github.com/AniZpZ/AutoSmoothQuant/tree/main/csrc/int8gemm.
How can I profile the w4a8 and w8a8 GEMM kernels like you did for this pic?
See https://github.com/HandH1998/QQQ/issues/2#issuecomment-2179921604.
Can a w8a8 quantized model also be exported in the QQQ repo? Simply by modifying wbits from 4 to 8?
There are many differences between w4a8 and w8a8, so you cannot simply modify wbits to 8 in QQQ.
@HandH1998 Regarding the w4a8-gs128 setting, why does QQQ use the w4 -> fp16 -> w8 pipeline rather than the w4 -> w8 pipeline of QServe-w4a8-gs128, since w4 -> fp16 -> w8 seems slower? (In QServe w4a8 with groups, there is still no w4 -> fp16 step, identical to w4a8 without groups.)
Why is W4A8 faster than W8A8? W4A8 needs some additional operations before performing INT8 GEMM. Does W8A8 here refer to per-token + per-channel?