HandH1998 / QQQ

QQQ is an innovative and hardware-optimized W4A8 quantization solution.

[QST] Scale factors and benchmarks #2

Open jeromeku opened 1 week ago

jeromeku commented 1 week ago

Great paper and thanks for open sourcing the code.

A couple of questions:

  1. Is the benchmarking code from Section 4 of the paper available (GEMM, FastFP16toInt8)?
  2. In the per-group W4A8 kernel, why is there a need for an additional channel-wise scale factor in FusedDequantQuant? I.e., the INT4 weights are dequantized to FP16 using group-wise scale factors, then quantized to INT8 using an additional channel-wise scale before being fed to the INT8 GEMM. In contrast, in the channel-wise W4A8 kernel, the INT4 weights are converted directly to INT8 and then fed to the INT8 GEMM.

HandH1998 commented 1 week ago

@jeromeku Reply to your questions:

  1. For the w4a8 GEMM benchmark, you can try bench_w4a8.py in my repo https://github.com/HandH1998/marlin/tree/w4a8. For the FastFP16toInt8 benchmark, I provide an older version of the GEMM code in this gist https://gist.github.com/HandH1998/b96922e0a0ab7da769fd93e34ffb068a, which serves as the baseline using the traditional instructions for converting from FP16 to INT8. You can put it in https://github.com/HandH1998/marlin/tree/w4a8 and run the benchmark.
  2. As there are multiple per-group scales within one channel of the weight, which are not directly compatible with a standard GEMM, we have to convert INT4 to FP16 and then to INT8. For the per-channel scale, we can factor the scales out of the GEMM as s_a * A * W * s_w, so there is no need for the complicated conversion required in the per-group case. A toy sketch of the two paths follows below.
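
For intuition, here is a toy NumPy sketch of the two paths (illustrative only, not the kernel code; the shapes and the max-based channel scale are assumptions made for the example):

    import numpy as np

    K, N, group = 256, 4, 64
    w_int4 = np.random.randint(-8, 8, size=(K, N)).astype(np.int8)       # unpacked INT4 weights
    s_group = (np.random.rand(K // group, N) * 0.1).astype(np.float16)   # group-wise scales

    # Per-group path: INT4 -> FP16 using the group-wise scales ...
    w_fp16 = w_int4.astype(np.float16) * np.repeat(s_group, group, axis=0)
    # ... then FP16 -> INT8 with one channel-wise scale per output channel, because the
    # INT8 GEMM accumulates over the whole K dimension and cannot apply a different
    # scale to each group inside the dot product.
    s_channel = np.abs(w_fp16).astype(np.float32).max(axis=0) / 127.0
    w_int8 = np.clip(np.round(w_fp16.astype(np.float32) / s_channel), -127, 127).astype(np.int8)

    # Per-channel path: the single scale factors out of the GEMM entirely,
    # Y = diag(s_a) @ (A_int8 @ W_int8) @ diag(s_w), so no FP16 detour is needed.
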
jeromeku commented 1 week ago

@HandH1998

Many thanks for the response!

Do you have the script used to test against other methods? Especially interested in reproducing the results against QoQ.

Also, I can't seem to find the FastINT4toINT8 conversion function used when converting from int4 -> int8.

HandH1998 commented 1 week ago

> @HandH1998
>
> Many thanks for the response!
>
> Do you have the script used to test against other methods? Especially interested in reproducing the results against QoQ.
>
> Also, I can't seem to find the FastINT4toINT8 conversion function used when converting from int4 -> int8.

You can reproduce the QQQ results by following the Usage section of the README.md. As for the FastINT4toINT8 conversion, you can refer to Section 3.3.1 of our paper. It actually just performs a left shift by 4 bits to convert INT4 to INT8, in this line: https://github.com/HandH1998/QQQ/blob/49f06e0b47c606ca2c5558ade0805b0609d57a8f/csrc/qqq_gemm.cu#L540.
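
To see why a single shift is enough, here is a small NumPy illustration (not the repo's CUDA code): shifting a signed int4 value from the low nibble into the high nibble of a byte yields 16 * w when the byte is read as int8, and the constant factor 16 can be folded into the weight scale.

    import numpy as np

    # Illustration only: int4 values w in [-8, 7] stored in the low nibble of a byte.
    # A left shift by 4 moves the int4 sign bit onto the int8 sign bit, so the result
    # equals 16 * w when reinterpreted as int8; the factor 16 is absorbed by the scale.
    int4_vals = np.arange(-8, 8, dtype=np.int8)
    nibbles = (int4_vals & 0x0F).astype(np.uint8)   # low-nibble two's-complement encoding
    as_int8 = (nibbles << 4).view(np.int8)          # FastINT4toINT8: one left shift
    assert np.array_equal(as_int8.astype(np.int16), int4_vals.astype(np.int16) * 16)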

brisker commented 4 days ago

@HandH1998 Is activation quantization in QQQ dynamic or static? (You only mentioned that it is per-token quantization.)

HandH1998 commented 4 days ago

@brisker dynamic quantization

brisker commented 4 days ago

@HandH1998 I noticed you compared your accuracy with QServe, but QServe is w4a8 with kv4, while QQQ seems to use an fp16 KV cache. Is this comparison fair?

HandH1998 commented 3 days ago

@brisker As QServe doesn't offer a w4a8f16 precision, we directly compare QQQ with QServe using w4a8kv4. On the other hand, QServe employs various techniques to mitigate the impact of kv4. According to their paper, SmoothAttention reduces perplexity by 0.05 without adding system overhead. Progressive group quantization further improves perplexity by an additional 0.02, with only a negligible increase in dequantization overhead. Lastly, activation-aware channel reordering improves perplexity by 0.03. As illustrated in the figure below, their ablation study shows kv4 only increases perplexity by 0.04 compared to kv8 when these techniques are applied. As we know, kv8 can deliver performance almost identical to an fp16 KV cache, so the impact of kv4 is negligible.

[figure: QServe ablation study on kv4 quantization techniques]

brisker commented 3 days ago

@HandH1998 The speedup of QQQ w4a8g128 compared to Marlin w4a16g128 seems to be very limited. I think this may be due to the fp16 KV cache of QQQ. Any plan to try QQQ w4a8g128-kv8?

HandH1998 commented 3 days ago

> @HandH1998 The speedup of QQQ w4a8g128 compared to Marlin w4a16g128 seems to be very limited. I think this may be due to the fp16 KV cache of QQQ. Any plan to try QQQ w4a8g128-kv8?

We think the speedup of QQQ w4a8g128 is limited by the high dtype conversion overhead between FP16 and INT8, as shown in the figure below. QQQ only focuses on weight quantization, and we don't plan to develop a w4a8g128-kv8. Replacing the fp16 KV cache with kv8 can increase computing throughput at large batch sizes, but it is not effective at small batch sizes. If you want to try QQQ with a low-bit KV cache, we recommend our vLLM PR, which provides an fp8 KV cache.

[figure: FP16-to-INT8 dtype conversion overhead]
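
For reference, a minimal sketch of how the fp8 KV cache is enabled on the user side (assuming a vLLM build that includes the PR; the model path and prompt are placeholders):

    from vllm import LLM, SamplingParams

    # Placeholder model path; kv_cache_dtype="fp8" requests the FP8 KV cache.
    llm = LLM(model="path/to/your/quantized-model", kv_cache_dtype="fp8")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)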

AniZpZ commented 2 days ago

> @HandH1998 The speedup of QQQ w4a8g128 compared to Marlin w4a16g128 seems to be very limited. I think this may be due to the fp16 KV cache of QQQ. Any plan to try QQQ w4a8g128-kv8?

Thank you for your advice! Currently, prefill speed is more essential for most inference cases, while KV cache quantization improves decode speed. KV8 is now well supported, and you are welcome to combine QQQ with KV cache quantization methods!

brisker commented 2 days ago

@AniZpZ @HandH1998 In the figure from your paper, there is a w8a8 inference speed. Was this w8a8 inference speed tested on vLLM? Which version of vLLM? Besides, why is w8a8 even slower than fp16 in your figure?

[figure: inference speed comparison from the paper]

HandH1998 commented 1 day ago

@brisker We developed a new version based on this PR to support dynamic per-token activation quantization. We think the online activation quantization introduces additional overhead, resulting in slower inference than FP16 at smaller batch sizes. However, as the batch size increases, the workload becomes compute-bound, and w8a8 is likely to outperform other quantization methods.
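
As a rough sketch of where that online overhead comes from (illustrative PyTorch, not the actual kernel; the function name and clamping are assumptions), the per-token scale has to be recomputed from the live activations on every forward pass:

    import torch

    def quantize_per_token(x: torch.Tensor):
        # x: [num_tokens, hidden]; the scale is derived from the activations at runtime,
        # so this work happens online at every step (unlike static quantization).
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_int8 = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
        return x_int8, scale  # the scale is applied back after the INT8 GEMM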

brisker commented 1 day ago

@HandH1998 And the fp16 speed in the figure is the vLLM fp16 speed (already equipped with PagedAttention and other acceleration methods), not the HuggingFace PyTorch inference speed, right?

HandH1998 commented 1 day ago

> @HandH1998 And the fp16 speed in the figure is the vLLM fp16 speed (already equipped with PagedAttention and other acceleration methods), not the HuggingFace PyTorch inference speed, right?

Yes.

brisker commented 3 hours ago

@HandH1998 @AniZpZ

  1. In the PR you mentioned, how do I save the corresponding w4a8-format model to test the w4a8 GEMM? Is it identical to the gptq-marlin w4 storage format?
  2. I use the default code and configs in this repo, except for commenting out these two lines (https://github.com/HandH1998/QQQ/blob/main/examples/quant_model.py#L70 and https://github.com/HandH1998/QQQ/blob/main/examples/quant_model.py#L61, otherwise I get a NaN loss), quantize Llama2-7B, and get the quantized models. I then use something like this to evaluate the w4a8 and fp16 inference speed:

     import time
     import torch
     from transformers import AutoModelForCausalLM

     kwargs = {"torch_dtype": torch.float16, "device_map": "auto", "attn_implementation": "eager"}
     fp16_model = AutoModelForCausalLM.from_pretrained(
         args.model_path, trust_remote_code=True, **kwargs
     )
     model = fp16_model  # or the w4a8 quantized model generated by QQQ
     time1 = time.time()
     output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
     time2 = time.time()
     print(f"decoding time: {time2 - time1}")

But the w4a8 inference time is nearly double that of fp16. Is there a bug in this repo? (The NaN loss during w4a8 quantization is also weird.)

fp16 decoding time: 3.2025535106658936
w4a8 decoding time: 5.649582147598267
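
(As a side note on the measurement itself, a warmup generation plus explicit CUDA synchronization around the timed region usually gives more stable numbers; a minimal sketch reusing the variables above:)

    import time
    import torch

    # Warm up once and synchronize around the timed region so lazy initialization
    # and asynchronous CUDA launches don't distort the comparison.
    model.generate(**inputs, max_new_tokens=8)
    torch.cuda.synchronize()
    start = time.time()
    output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
    torch.cuda.synchronize()
    print(f"decoding time: {time.time() - start}")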