mobiusml / hqq

Official implementation of Half-Quadratic Quantization (HQQ)
https://mobiusml.github.io/hqq_blog/
Apache License 2.0

Packing Format #32

Closed by jeromeku 7 months ago

jeromeku commented 8 months ago

Great project and clean implementation!

Wondering what the motivation is for packing 4-bit tensors (into u8) such that the first half of the tensor is interleaved with the second half.

More specifically, each byte contains a 4-bit val from the first half of the tensor (in the upper bits) and a 4-bit val from the second half of the tensor (in the lower bits), as opposed to packing consecutive values, such that unpacking could more easily yield contiguous segments of the original tensor.
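To make the two layouts concrete, here is a minimal plain-Python sketch on a flat list of 4-bit values. The function names `pack_halves` and `pack_consecutive` are hypothetical and used only for illustration; they are not part of the hqq API.

```python
def pack_halves(vals):
    """Byte k holds vals[k] (upper nibble) and vals[k + n/2] (lower nibble)."""
    assert len(vals) % 2 == 0 and all(0 <= v <= 15 for v in vals)
    half = len(vals) // 2
    return bytes((vals[k] << 4) | vals[k + half] for k in range(half))


def pack_consecutive(vals):
    """Byte k holds vals[2k] (upper nibble) and vals[2k + 1] (lower nibble)."""
    assert len(vals) % 2 == 0 and all(0 <= v <= 15 for v in vals)
    return bytes((vals[2 * k] << 4) | vals[2 * k + 1] for k in range(len(vals) // 2))


vals = [1, 2, 3, 4, 5, 6, 7, 8]
print(pack_halves(vals).hex())       # 15263748
print(pack_consecutive(vals).hex())  # 12345678
```

With the half-split layout a single byte mixes elements that are n/2 apart, while the consecutive layout keeps neighbours together in one byte.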

mobicham commented 8 months ago

Thank you @jeromeku !

The reason is that it's the fastest way we could find to do bit-unpacking in pure PyTorch. The PyTorch backend allows more people to use the package and benefit from quantization.

The first implementation was pure PyTorch and relied on torch.compile for the speed-up. To do bit-unpacking in pure PyTorch, you need to reduce the number of operations and copy chunks that are as large as possible. That is why bit-unpacking with int32 via PyTorch is significantly slower than doing it with uint8. Later we added bit-unpacking and dequant CUDA kernels, and it was very easy to port the same bit-unpacking logic. Since it was working fairly well (faster than torch.compile), we didn't think about improving it.
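The "few operations, large chunks" argument can be sketched in plain Python, with lists standing in for tensors. The function names are hypothetical. In this sketch the half-split layout needs two whole-array passes and one concatenation, whereas the consecutive layout needs the same two passes plus an element-by-element interleave.

```python
def unpack_halves(packed):
    """Two bulk passes over the bytes, then a single concatenation."""
    first = [b >> 4 for b in packed]     # upper nibbles: first half of the tensor
    second = [b & 0x0F for b in packed]  # lower nibbles: second half of the tensor
    return first + second


def unpack_consecutive(packed):
    """Same two passes, but the results are written to alternating positions."""
    out = [0] * (2 * len(packed))
    out[0::2] = [b >> 4 for b in packed]     # even positions
    out[1::2] = [b & 0x0F for b in packed]   # odd positions
    return out


# Both layouts decode to the same values when paired with their own packing.
print(unpack_halves(bytes([0x15, 0x26, 0x37, 0x48])))       # [1, 2, 3, 4, 5, 6, 7, 8]
print(unpack_consecutive(bytes([0x12, 0x34, 0x56, 0x78])))  # [1, 2, 3, 4, 5, 6, 7, 8]
```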

Do you think it would be faster in CUDA to do it as you described? I am curious because, even with shared memory, I had a hard time speeding up the dequant compared to the vanilla simple implementation: https://github.com/mobiusml/hqq/blob/master/hqq/kernels/hqq_aten_cuda_kernel.cu#L204

jeromeku commented 8 months ago

@mobicham

The issue isn't so much the speed of a standalone dequant / unpack kernel -- though this can be sped up with some bit-hacking tricks -- but more so the downstream ops.

Specifically, it would be useful to fuse dequant with gemv / gemm. Having the values packed such that threads are able to unpack and have access to contiguous elements in registers without additional shuffling would be useful.
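As a rough illustration of what fusing dequant with gemv means, the plain-Python sketch below computes the matrix-vector product directly from consecutively packed bytes, without materializing the dequantized matrix. The name `gemv_fused_consecutive` and the per-row `scale` / `zero` are simplifying assumptions for the example (HQQ itself quantizes in groups), and a real kernel would do this per thread in registers.

```python
def gemv_fused_consecutive(packed_rows, scale, zero, x):
    """y[r] = scale[r] * sum_c (q[r][c] - zero[r]) * x[c], read from packed bytes."""
    y = []
    for r, row in enumerate(packed_rows):
        acc = 0.0
        for k, byte in enumerate(row):
            # One byte yields two adjacent columns, 2k and 2k + 1,
            # so they pair up with two adjacent entries of x.
            q0, q1 = byte >> 4, byte & 0x0F
            acc += (q0 - zero[r]) * x[2 * k] + (q1 - zero[r]) * x[2 * k + 1]
        y.append(acc * scale[r])
    return y


packed_rows = [bytes([0x12, 0x34]), bytes([0xFF, 0x00])]  # rows [1,2,3,4] and [15,15,0,0]
print(gemv_fused_consecutive(packed_rows, [0.5, 2.0], [1.0, 8.0], [1.0, 2.0, 3.0, 4.0]))
# [10.0, -70.0]
```

With the half-split layout, the two values in a byte would instead pair with entries of the output or input that are far apart, which is the extra shuffling referred to above.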

mobicham commented 8 months ago

Yeah definitely, a fused kernel would be great. Currently, dequant() + torch.matmul is about 12-18% slower than torch.matmul in fp16, which is about the same speed as bitsandbytes, which also uses CUDA kernels. A fused kernel should actually be faster than fp16 because more data can be copied to the shared memory/cache.

The issue, however, is that this has to be done properly, otherwise you end up with much slower fused kernels when running on older GPUs. For example, there are fused Triton kernels for 4-bit HQQ that use the bit-packing logic you mentioned (https://github.com/dvmazur/mixtral-offloading/blob/master/src/triton_kernels.py), but they actually run significantly slower than plain PyTorch + torch.compile on older GPUs. That's why we didn't include them as a separate backend in the project.

If you have any ideas/suggestions on how to speed up dequant or fused kernels, ideally in CUDA rather than Triton, please feel free to share any tips or contribute to the project!

jeromeku commented 8 months ago

@mobicham

It seems most quantization libraries adopt a specialized packing / preprocessed format in order to optimize inference speed (GPTQ / Marlin, AWQ).

The primary issue is that, with quantized types, loading the weights requires special processing to effectively leverage ldmatrix and mma.sync, the primary tensor-core primitives, which typically expect fp16 / bf16 types.

FasterTransformer introduced a number of innovations for mixed-type gemm targeted primarily at pre-Hopper architectures, which would also require preprocessing weights to a special format.

mobicham commented 8 months ago

I see, my CUDA knowledge is not as advanced, but I understand what you're saying. It's not really an issue to add additional bit-packing methods so that the processing is aligned with tensor-core primitives. We can have a separate backend that uses the packing format you mentioned. Do you by any chance have some examples/resources you could point me to? I guess something like this, but as I mentioned, this Triton implementation is slower than PyTorch.

jeromeku commented 8 months ago

@mobicham

Let me give this some further thought and get back to you. Been investigating how to generalize across different quantization types and kernels.

mobicham commented 8 months ago

@jeromeku these guys use a similar packing logic (uint8 storing two 4-bit values, with a step of 2), then they call a cutlass gemm kernel with 4-bit/4-bit-packed tensors. The paper claims a 4x speed-up. I wonder if the same is possible with grouping and asymmetric quantization (the kernel seems to work with scales only): https://github.com/spcl/QuaRot/blob/main/quarot/functional/quantization.py#L42

jeromeku commented 8 months ago

@mobicham

Regarding QuaRot, the cutlass kernel they're using is an IGEMM (s4 x s4 => s32), so there is no upcasting to fp16 / bf16 (which means incorporating a zero point wouldn't work). Doing an efficient mixed-type gemm, where the quantized values are first converted to a higher-precision type and then scaled / shifted, would require a different gemm.

I am currently looking into a few things and will update:

mobicham commented 8 months ago

Thanks @jeromeku . Also that cutlass kernel doesn't support grouping. I am looking into ways to use Marlin, which doesn't use zero-point either and no grouping other than 128, but with some math magic I was able to make it work with HQQ's dequant logic on a toy example. Now the challenge is to get high-quality quantized models with such lower group-sizes, I have some promising results on Llama2-7B base, let's see how that generalizes to instruct models.
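The comment does not say what the "math magic" is, so the following is only one plausible identity, not the author's confirmed method. If a kernel hard-codes a zero-point `z0`, then `s*(q - z) = s*(q - z0) + s*(z0 - z)`, so a dot product with an arbitrary zero-point `z` equals the kernel's output plus a cheap correction proportional to the sum of the activations. The value `z0 = 8` and both function names are assumptions chosen for illustration.

```python
def dot_with_zero_point(q, x, s, z):
    """Reference: dequantize each weight as s * (q - z), then take the dot product."""
    return sum(s * (qk - z) * xk for qk, xk in zip(q, x))


def dot_via_fixed_zero(q, x, s, z, z0=8):
    """Same result from a kernel that only supports the fixed zero-point z0."""
    kernel_out = s * sum((qk - z0) * xk for qk, xk in zip(q, x))  # fixed-zero kernel output
    correction = s * (z0 - z) * sum(x)                            # one extra term per group
    return kernel_out + correction


q, x = [3, 15, 0, 9], [1.0, -2.0, 0.5, 4.0]
print(dot_with_zero_point(q, x, s=0.25, z=5.5))  # -2.5625
print(dot_via_fixed_zero(q, x, s=0.25, z=5.5))   # -2.5625
```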

That sounds great. I made some attempts to make dequant() faster by caching data in shared memory plus manual vectorization, but it didn't speed things up. It seems there's already enough work for the threads, so with step=4 it's actually slower than the simple kernel:


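template <typename scalar_t>
__global__ void dequantize_4bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
    const int step    = 4;    // packed bytes handled per thread
    const int threads = 256;  // assumption: the kernel is launched with blockDim.x == threads
    int i = (blockIdx.x*blockDim.x + threadIdx.x)*step;  // first packed byte of this thread
    int t = threadIdx.x*step;                            // this thread's slots in shared memory
    int n = h*w;                                         // total number of packed bytes

    const int C1  = 0xF0;  // mask for the upper nibble (first half of the tensor)
    const int C2  = 0x0F;  // mask for the lower nibble (second half of the tensor)

    // Shared memory is per block, so every thread gets its own `step` slots.
    __shared__ unsigned char shared_Wq[threads*step];
    __shared__ scalar_t shared_meta[2][threads*step];
    __shared__ unsigned char C[2];

    // Stage the packed bytes with their zero/scale. There is no early return:
    // every thread must reach __syncthreads(), so out-of-range work is skipped.
    #pragma unroll
    for (int s = 0; s < step; s++){
        if (i + s < n){
            int j                 = (i + s) % w;
            shared_Wq[t + s]      = Wq_packed[i + s];
            shared_meta[0][t + s] = zero[j];
            shared_meta[1][t + s] = scale[j];
        }
    }
    if (threadIdx.x == 0){
        C[0] = C1;
        C[1] = C2;
    }
    __syncthreads();

    #pragma unroll
    for (int s = 0; s < step; s++){
        if (i + s < n){
            W_r[i + s]     = (scalar_t((shared_Wq[t + s] & C[0]) >> 4) - shared_meta[0][t + s])*shared_meta[1][t + s];   //First chunk
            W_r[i + s + n] = (scalar_t((shared_Wq[t + s] & C[1]))      - shared_meta[0][t + s])*shared_meta[1][t + s];   //Second chunk
        }
    }
}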
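For checking a kernel like this one, a slow plain-Python reference of what it is meant to compute can help. This is a sketch read off from the indexing in the kernel, under the assumption that `h` and `w` are the dimensions of the packed tensor and that `scale` and `zero` hold one entry per column. The function name is hypothetical.

```python
def dequantize_4bit_u8_reference(wq_packed, scale, zero, h, w):
    """Flat reference: upper nibbles fill W_r[0:n], lower nibbles fill W_r[n:2n]."""
    n = h * w                      # number of packed bytes
    w_r = [0.0] * (2 * n)
    for i in range(n):
        j = i % w                  # column index selects zero/scale
        w_r[i] = ((wq_packed[i] >> 4) - zero[j]) * scale[j]        # first chunk
        w_r[i + n] = ((wq_packed[i] & 0x0F) - zero[j]) * scale[j]  # second chunk
    return w_r


packed = bytes([0x15, 0x26, 0x37, 0x48])
print(dequantize_4bit_u8_reference(packed, scale=[0.5, 2.0], zero=[1.0, 2.0], h=2, w=2))
# [0.0, 0.0, 1.0, 4.0, 2.0, 8.0, 3.0, 12.0]
```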

That's correct, for fine-tuning, you either need dequant() or transposed matmul.

Regarding 1-bit, someone did a nice CPU experiment, seems that the speed-up is real! https://github.com/catid/bitnet_cpu

Feel free to join us on Discord: https://discord.com/invite/VJcFz5TR we can discuss all of this!