jeromeku opened 1 week ago
@jeromeku Reply to your questions:

For the w4a8 GEMM benchmark, you can try `bench_w4a8.py` in my repo https://github.com/HandH1998/marlin/tree/w4a8. For the FastFP16toInt8 benchmark, I provide an old version of the GEMM code in this gist: https://gist.github.com/HandH1998/b96922e0a0ab7da769fd93e34ffb068a. It is the baseline using the traditional instruction sequence (INT4 to FP16, then to INT8). You can put it in https://github.com/HandH1998/marlin/tree/w4a8 and run the benchmark.

For per-channel scale, we can split the per-channel scales out of the GEMM as `s_a * A * W * s_w`, so there is no need to do the complicated conversion required for per-group.
@HandH1998 Many thanks for the response!
Do you have the script used to test against other methods? I'm especially interested in reproducing the results against QoQ.
Also, I can't seem to find the FastINT4toINT8 conversion function used when converting from int4 -> int8.
You can reproduce the QQQ results by following the `Usage` section of the README. As for the FastINT4toINT8 conversion, you can refer to Section 3.3.1 of our paper. Actually, it just performs a left shift by 4 bits to convert int4 to int8, in this line: https://github.com/HandH1998/QQQ/blob/49f06e0b47c606ca2c5558ade0805b0609d57a8f/csrc/qqq_gemm.cu#L540.
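The effect of that left shift can be reproduced in NumPy for illustration. The single-byte nibble packing below is a simplified stand-in, not the kernel's actual 32-bit register layout:

```python
import numpy as np

# One byte holding two signed 4-bit weights (simplified packing: low and high nibble).
packed = np.array([0b0111_1101], dtype=np.uint8)  # high nibble = 7, low nibble = -3

# Left shift by 4 moves the LOW nibble into the top bits: the int4 sign bit
# lands on the int8 sign bit, so the result equals (int4 value) * 16 as int8,
# and the extra factor of 16 is folded into the dequantization scale.
low_times_16 = (packed << 4).view(np.int8)        # -3 * 16 = -48

# The traditional route instead sign-extends the high nibble with an
# arithmetic right shift after reinterpreting the byte as signed.
high = packed.view(np.int8) >> 4                  # 7
```

So a single shift both extracts the nibble and produces a valid int8 value, at the cost of a constant factor of 16 that the scale absorbs.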
@HandH1998 Are activations quantized dynamically or statically in QQQ? (You only mentioned that it is per-token quantization.)
@brisker dynamic quantization
@HandH1998 I noticed you have compared your accuracy with QServe, but QServe is w4a8 with kv4, while QQQ seems to use an fp16 KV cache. Is this comparison fair?
@brisker As QServe doesn't offer a w4a8f16 precision, we directly compare QQQ with QServe using w4a8kv4. On the other hand, QServe employs various techniques to mitigate the impact of kv4. According to their paper, SmoothAttention reduces perplexity by 0.05 without adding system overhead. Progressive group quantization further improves perplexity by an additional 0.02, with only a negligible increase in dequantization overhead. Lastly, activation-aware channel reordering improves perplexity by 0.03.
As illustrated in the following figure, the ablation study shows kv4 only increases perplexity by 0.04 compared to kv8 with these techniques. As we know, kv8 can deliver performance almost identical to fp16 kv cache, so the impact of kv4 is negligible.
@HandH1998 The speedup of QQQ w4a8g128 compared to Marlin w4a16g128 seems to be very limited; I think this may be due to QQQ's fp16 KV cache. Any plan to try QQQ w4a8g128-kv8?
We think the speedup of QQQ w4a8g128 is limited by the high dtype-conversion overhead between FP16 and INT8, as shown in the following picture. QQQ only focuses on weight quantization, and we don't plan to develop a w4a8g128-kv8. Actually, replacing the fp16 KV cache with kv8 can increase computing throughput at large batch sizes, but it is not effective at small batch sizes. If you want to try QQQ with a low-bit KV cache, we recommend our vLLM PR, which provides an fp8 KV cache.
Thank you for your advice! Currently, prefill speed matters more for most inference cases, while KV-cache quantization improves decode speed. KV8 is now well supported, and you are welcome to combine QQQ with KV-cache quantization methods!
@AniZpZ @HandH1998
In the figure of your paper, there is a w8a8 inference speed. Was this w8a8 speed tested on vLLM? Which version of vLLM?
Besides, why is w8a8 even slower than fp16 in your figure?
@brisker We developed a new version based on this PR to support dynamic per-token activation quantization. We think the online activation quantization introduces additional overhead, resulting in slower inference than FP16 at smaller batch sizes. However, as the batch size increases, the scenario becomes compute-bound, and w8a8 is likely to outperform other quantization methods.
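The online per-token quantization step being described can be sketched roughly as follows. This is an illustrative NumPy sketch, not the actual CUDA kernel; the function name and shapes are hypothetical:

```python
import numpy as np

def per_token_quant_int8(x: np.ndarray):
    """Dynamic per-token symmetric int8 quantization: one scale per row
    (token), computed at runtime from the row's max absolute value."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1e-8)                 # guard against all-zero rows
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

# 2 tokens with 16 features each.
x = np.random.default_rng(1).standard_normal((2, 16)).astype(np.float32)
q, s = per_token_quant_int8(x)
x_hat = q.astype(np.float32) * s                    # dequantized approximation
```

The max-reduction and rounding pass over every activation at every step is exactly the kind of runtime overhead that can make w8a8 slower than fp16 when batches are small and the workload is memory-bound.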
@HandH1998 And the fp16 speed in the figure is the vLLM fp16 speed (i.e., already using paged attention and other acceleration methods), not Hugging Face PyTorch inference speed, right?
Yes.
@HandH1998 @AniZpZ

```python
import time
import torch
from transformers import AutoModelForCausalLM

kwargs = {"torch_dtype": torch.float16, "device_map": "auto", "attn_implementation": "eager"}
model = AutoModelForCausalLM.from_pretrained(
    args.model_path, trust_remote_code=True, **kwargs
)

torch.cuda.synchronize()  # ensure prior GPU work doesn't pollute the timing
time1 = time.time()
# model can be fp16 or the w4a8 quantized model generated by QQQ
output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
torch.cuda.synchronize()
time2 = time.time()
print(f"decoding time: {time2 - time1}")
```
But the w4a8 inference time is nearly double that of fp16. Is there a bug in this repo? (The NaN loss during w4a8 quantization is also weird.)

fp16 decoding time: 3.2025535106658936
w4a8 decoding time: 5.649582147598267
Great paper and thanks for open-sourcing the code.

A couple of questions:

1) Is the benchmarking code in Section 4 of the paper available (GEMM, FastFP16toInt8)?
2) In the per-group W4A8 kernel, why is there a need for an additional channel-wise scale factor in FusedDequantQuant? I.e., the Int4 weights are dequantized to FP16 using group-wise scale factors, then quantized to Int8 using an additional channel-wise scale, and then fed to the Int8 GEMM. In contrast, in the channel-wise W4A8 kernel, the Int4 weights are directly converted to Int8 and then fed to the Int8 GEMM.
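For concreteness, the two weight paths described in question 2 can be sketched in NumPy. Shapes, group size, and scale values here are hypothetical, chosen only to illustrate the data flow:

```python
import numpy as np

rng = np.random.default_rng(2)
K, N, G = 8, 4, 4                     # hypothetical: K x N weight, group size G along K

w_int4 = rng.integers(-8, 8, size=(K, N)).astype(np.int8)
s_group = rng.uniform(0.5, 1.5, size=(K // G, N)).astype(np.float32)
s_channel = rng.uniform(0.5, 1.5, size=(1, N)).astype(np.float32)

# Per-group path: dequantize int4 -> fp with the group scales, then requantize
# to int8 with the channel-wise scale before the int8 GEMM (the
# dequant-then-quant step the question refers to as FusedDequantQuant).
w_fp = w_int4.astype(np.float32) * np.repeat(s_group, G, axis=0)
w_int8_grouped = np.clip(np.round(w_fp / s_channel), -127, 127).astype(np.int8)

# Channel-wise path: no group scales exist, so the int4 weights map to int8
# directly (e.g. via the shift-by-4 trick, giving w * 16), and only the
# channel-wise scale is applied after the GEMM.
w_int8_channel = w_int4 << 4
```

The sketch makes the asymmetry visible: only the per-group path needs a float round trip, because its group scales cannot be folded into a single post-GEMM per-channel factor.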