
Faster Pytorch dequantize() + matmul for quantized models #115985

Open mobicham opened 8 months ago

mobicham commented 8 months ago

Fast Pytorch dequantize() + matmul

I would like to open the discussion about faster inference with quantized models using pure Pytorch calls.

As you know, quantization is extremely important for running large models like LLMs with limited GPU memory. This is especially important for the open-source community: making these models more accessible is a huge step forward.

The quality of quantized models is improving at a fast pace. With the recent releases of HQQ and Quip#, we have now reached a point where 4-bit and even 2-bit models work fairly well, and this will surely improve in the upcoming months.

Now the main remaining issue is inference speed. While some libraries like llama.cpp are able to achieve faster inference with quantized models, a naive implementation in Pytorch can be 2x slower than fp16.

Alternatives

Custom Cuda/Triton kernels

Most of the methods like BNB/GPTQ/AWQ/Quip etc. implement custom CUDA/Triton kernels to do the dequantize() + matmul step. These kernels require a very low-level implementation that can only be done efficiently in special cases. Moreover, they are usually optimized for newer Nvidia chips like the A100, and this approach cannot be reused for other hardware like CPUs or Intel's Arc. In fact, some Triton kernels optimized for the A100 may run slower on older GPUs.

Naive Pytorch Implementation

With a pure Pytorch implementation, this is basically implemented as follows:

def forward(self, x):
    # Dequantize the full weight matrix on every call, then run a regular matmul.
    W_est = self.dequantize()
    out   = torch.matmul(x, W_est.t())
    if self.bias is not None:
        out += self.bias
    return out

The issue is that this is relatively slow due to the dequantize() call.
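
For context, dequantize() here essentially reconstructs a floating-point estimate of the weights from the low-bit representation. Below is a minimal sketch of the affine case, assuming per-channel scale and zero-point tensors; the actual HQQ implementation also includes the bit-unpacking and grouping steps:

import torch

def dequantize(W_q, scale, zero):
    # W_q: low-bit integer weights stored in a uint8 tensor.
    # scale / zero: quantization parameters, broadcastable against W_q.
    # Affine reconstruction: W_est = (W_q - zero) * scale
    return (W_q.to(scale.dtype) - zero) * scale

This reconstruction is re-executed on every forward pass, which is where the overhead comes from.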

Torch Compile

This actually does make things faster with dynamo compile, but:

ATen C++ Backend

I re-implemented the whole dequantization logic in ATen: https://github.com/mobiusml/hqq/blob/master/hqq/kernels/hqq_aten.cpp, but the dynamo compile option is still faster.

I found this new function in ATen that could be used, but there is no documentation on how to use it: https://pytorch.org/cppdocs/api/function_namespaceat_1adeda9630914278ac02d7fd758da19e3d.html

Ideas

Efficient Native Pytorch bit-unpacking

Really, the main speed bottleneck here is the bit-unpacking. I think we need a native, well-implemented bit-packing/unpacking function in Pytorch. The current hack to represent low-bit tensors is to store them in uint8/uint32. The bit-unpacking operation requires extra copies, and when it is implemented via torch.cat, for example, it is not optimized because the chunks are processed serially. For 3-bit it is actually much worse: we need to use uint32 instead of uint8 for bit-packing, which requires 10 mask+shift+copy operations plus another slicing operation because the weight shapes are not a multiple of 3.
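
To make the mask+shift+copy cost concrete, here is a minimal sketch of packing four 2-bit values into one uint8 and unpacking them again; the layout is purely illustrative and not the one HQQ or torchao actually use:

import torch

def pack_2bit(w):
    # w: uint8 tensor with values in [0, 3] and shape (4*n, cols).
    # Pack groups of 4 values into a single uint8 each.
    w = w.reshape(4, -1, w.shape[1])
    return (w[0] << 6) | (w[1] << 4) | (w[2] << 2) | w[3]

def unpack_2bit(w_packed):
    # Each uint8 yields 4 values via mask + shift; torch.cat materializes
    # 4 intermediate chunks before the result can be used in a matmul.
    return torch.cat([(w_packed >> 6) & 0b11,
                      (w_packed >> 4) & 0b11,
                      (w_packed >> 2) & 0b11,
                       w_packed       & 0b11], dim=0)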

In an ideal world, it would be awesome to have native Pytorch uint4, uint3, uint2 dtypes that work via casting:

t_2bit_as_8bit = torch.tensor([0, 1, 2, 3], dtype=torch.uint8)
t_2bit_as_2bit = t_2bit_as_8bit.to(torch.uint2) # actual 2-bit tensor

I know that quint4x2 already exists but, correct me if I am wrong, it is currently not possible to do uint8 <-> quint4x2 casting. I was not able to use it directly.

I think that with just a more efficient implementation of bit-unpacking, plus Dynamo for compilation, we can potentially have something just as fast as fp16, if not faster. The implementation can simply do the same mask+shift+copy, but in a more efficient way. For example, I noticed that replacing the torch.cat operation by first allocating an empty tensor and then copying the chunks into it can be noticeably faster.
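
As a concrete example of that last point, the torch.cat in the unpacking step can be replaced by a single preallocated output that the four mask+shift chunks are written into. A minimal sketch, using the same illustrative 2-bit layout as above (actual gains depend on shapes and hardware):

import torch

def unpack_2bit_prealloc(w_packed):
    rows, cols = w_packed.shape
    # Allocate the full output once instead of concatenating 4 intermediate tensors.
    out = torch.empty(4 * rows, cols, dtype=torch.uint8, device=w_packed.device)
    out[0 * rows:1 * rows] = (w_packed >> 6) & 0b11
    out[1 * rows:2 * rows] = (w_packed >> 4) & 0b11
    out[2 * rows:3 * rows] = (w_packed >> 2) & 0b11
    out[3 * rows:4 * rows] =  w_packed       & 0b11
    return out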

Fused Dequantize() + matmul() kernels

As I mentioned earlier, there is one here: https://pytorch.org/cppdocs/api/function_namespaceat_1adeda9630914278ac02d7fd758da19e3d.html. I am curious to know whether you plan to have something similar for 2-bit and 3-bit. Also, how much faster is it compared to fp16? This solution will not work for codebook-based approaches, though.
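
For schemes without a codebook, the dequant + matmul can at least be expressed as a single function handed to torch.compile, so the compiler gets a chance to fuse the dequantization arithmetic into its generated kernels. A minimal sketch of the affine case (this is not the ATen int4 kernel linked above, and the achievable fusion depends on the backend and shapes):

import torch

@torch.compile  # default Inductor backend
def dequant_matmul(x, W_q, scale, zero):
    # Affine dequantization followed by the matmul; written as one function so
    # the compiler can try to avoid some of the intermediate copies that a
    # separate dequantize() + matmul() sequence would create.
    W_est = (W_q.to(x.dtype) - zero) * scale
    return torch.matmul(x, W_est.t())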

Conclusion

I think it would be really great to open the discussion about this topic. Happy to hear about your roadmap for accelerating quantized models with Pytorch. Thanks!

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

vadimkantorov commented 8 months ago

Related discussion on packing/unpacking bit tensors:

cpuhrsch commented 8 months ago

Also see https://github.com/pytorch-labs/ao cc @HDCharles @supriyar

jerryzh168 commented 8 months ago

Thanks, we are creating a prototype for an int4 tensor here: https://github.com/pytorch-labs/ao/pull/13 as the frontend for all the int4/int3/int2 etc. tensors.

As for performance, I think theoretically a combination of torch.compile codegen (e.g. for the dequantize op, or also codegen for the fused dequant + mm) and custom kernels (e.g. a manually implemented fused dequant + mm) should help.

mobicham commented 8 months ago

@jerryzh168 thank you for your reply!

I see that the int4 implementation is based on Pytorch ops. Did you benchmark the speed of the bit-unpacking + torch.compile()? I was doing something very similar (just using torch.cat instead of torch.stack). This version is a bit faster: https://github.com/mobiusml/hqq/blob/master/hqq/core/bitpack.py#L35

jerryzh168 commented 8 months ago

> @jerryzh168 thank you for your reply!
>
> I see that the int4 implementation is based on Pytorch ops. Did you benchmark the speed of the bit-unpacking + torch.compile()? I was doing something very similar (just using torch.cat instead of torch.stack). This version is a bit faster: https://github.com/mobiusml/hqq/blob/master/hqq/core/bitpack.py#L35

Oh, I haven't gotten to performance benchmarking yet; right now I'm just trying to make the frontend work. I think we can split this into two steps: (1) make sure this uint4 tensor subclass works as a frontend for all the flows we want to support, (2) performance tuning for the GPU flow and other flows, which we can explore after the PR is landed.