microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] 8 bit quantized inference not as fast as hoped for? #4560

Open Epliz opened 11 months ago

Epliz commented 11 months ago

Describe the bug I am using the 4-bit post-init quantization approach. I was hoping it would make inference faster in addition to saving memory, but that is not the case.

To Reproduce Quantize a model such as StarCoder.

Expected behavior Since inference is memory-bound by reading the model weights, I expected it to get almost linearly faster after quantizing. Instead it is about 2x slower. Is that because the weights are first dequantized and then torch's linear op is used? I assume it would be faster if the dequantization and the matmul were fused, as sketched below.
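
For context, the slowdown is what I would expect from a forward pass that first materializes a full-precision weight matrix and then calls the regular linear op, instead of reading the quantized weights directly. A minimal PyTorch sketch of that pattern (this is not the actual DeepSpeed QuantizedLinear code; the shapes and the simple symmetric per-row int8 scheme are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # CPU fp16 GEMM support varies

# One linear layer, single decoded token (batch size 1).
out_features, in_features = 4096, 4096
x = torch.randn(1, in_features, dtype=dtype, device=device)

# Illustrative symmetric per-row int8 quantization of the weights.
w = torch.randn(out_features, in_features, dtype=dtype, device=device)
scales = w.abs().amax(dim=1, keepdim=True) / 127.0
w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)

# Dequantize-then-linear: a full-precision copy of the whole weight matrix is
# materialized and then read again by the GEMV, so per-token memory traffic
# goes up rather than down, which matches the slowdown I observe.
w_dq = w_q.to(dtype) * scales
y = F.linear(x, w_dq)

# A fused kernel would read w_q (1 byte/weight) plus the small scale vector and
# dequantize in registers, which is where the expected speedup would come from.
```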

ds_report output Please run ds_report to give us details about your setup.

Screenshots If applicable, add screenshots to help explain your problem.

System info (please complete the following information):

Docker context Are you using a specific docker image that you can share?

Additional context Add any other context about the problem here.

Epliz commented 11 months ago

Hi @RezaYazdaniAminabadi, sorry for bothering you, but you seem to be the relevant person for this topic (based on your open PR touching quantization at https://github.com/microsoft/DeepSpeed/pull/4351). Could you say whether there are plans to improve the QuantizedLinear layer to avoid the dequantization at https://github.com/microsoft/DeepSpeed/blob/f060407829f87da32a267a60d26d13a68dc11c61/deepspeed/inference/quantization/layers.py#L66, at least when processing a single new token (batch-size-1 decoding)? If there are none in the mid-term (the next 3 months or so), I could try to find the time to contribute a CUDA kernel that does the GEMV with fused dequantization of the weights.
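
For what it's worth, a back-of-envelope traffic count (my own rough estimate, assuming decode is bound by weight bandwidth) lines up with both the ~2x slowdown above and the speedup a fused GEMV should give:

```python
# Back-of-envelope bytes moved per token for one 4096x4096 linear layer in the
# decode phase (illustrative int8 case; ignores activations, scales, and caches).
params = 4096 * 4096

fp16_baseline = params * 2            # read fp16 weights once
int8_fused = params * 1               # fused GEMV reads int8 weights directly
int8_dequant = params * (1 + 2 + 2)   # read int8, write fp16 copy, read it back

print(f"fp16 baseline      : {fp16_baseline / 1e6:.1f} MB")
print(f"int8 fused GEMV    : {int8_fused / 1e6:.1f} MB (~2x less traffic, ~4x for 4-bit)")
print(f"int8 dequant+linear: {int8_dequant / 1e6:.1f} MB (~2.5x the baseline traffic)")
```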