Thank you @jeromeku !
The reason is that it's the fastest way we could find to do bit-unpacking in pure Pytorch. The Pytorch backend allows more people to use the package and benefit from quantization.
The first implementation was pure Pytorch and relied on torch.compile for the speed-up. To do bit-unpacking in pure Pytorch, you need to reduce the number of operations and copy chunks that are as large as possible. That is why bit-unpacking with int32 in Pytorch is significantly slower than doing it with uint8. Later we added bit-unpacking and dequant CUDA kernels, and it was very easy to port the same bit-unpacking logic. Since it was working fairly well (faster than torch.compile), we didn't think about improving it.
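Roughly, that packing/unpacking looks like this in pure Pytorch (a simplified sketch, not the exact hqq code): the first half of the tensor goes into the upper nibbles and the second half into the lower nibbles, so both directions stay down to a couple of vectorized uint8 ops.

import torch

def pack_4bit_u8(W_q):
    # W_q: 4-bit codes in [0, 15]; the first dim is assumed to be even.
    W_q = W_q.to(torch.uint8)
    n = W_q.shape[0] // 2
    # First half -> upper nibble, second half -> lower nibble, in one shot.
    return (W_q[:n] << 4) | W_q[n:]

def unpack_4bit_u8(W_packed):
    # Two masked ops + one concat recover the original ordering.
    return torch.cat([(W_packed & 0xF0) >> 4, W_packed & 0x0F], dim=0)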
Do you think doing it the way you described would be faster in CUDA? I am curious to know because, even with shared memory, I had a hard time speeding up the dequant time compared to the vanilla simple implementation: https://github.com/mobiusml/hqq/blob/master/hqq/kernels/hqq_aten_cuda_kernel.cu#L204
@mobicham
The issue isn't so much the speed of a standalone dequant / unpack kernel -- though this can be sped up with some bit-hacking tricks -- but more so the downstream ops. Specifically, it would be useful to fuse dequant with gemv / gemm. Having the values packed such that threads can unpack and access contiguous elements in registers, without additional shuffling, would be useful.
Yeah definitely, a fused kernel would be great. Currently, dequant() + torch.matmul is about 12-18% slower than torch.matmul in fp16, about the same speed as bitsandbytes, which also uses CUDA kernels. A fused kernel should actually be faster than fp16 because more data can be copied into shared memory/cache.
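To be concrete about what that 12-18% refers to, the non-fused path is essentially the following (dequantize here is an illustrative stand-in for the CUDA dequant kernel, not the actual API name):

import torch

def linear_nonfused(x, W_q_packed, scale, zero):
    # Materialize the full fp16 weight first, then run a regular fp16 matmul.
    W_r = dequantize(W_q_packed, scale, zero)   # placeholder for the CUDA dequant kernel
    return torch.matmul(x, W_r.t())

A fused kernel would only stream the packed uint8 weights through global memory (4x less traffic than fp16) and unpack them inside the matmul tiles, which is why it could beat the fp16 matmul.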
The issue, however, is that this should be done properly, otherwise you end up with much slower fused kernels when running on older GPUs. For example, there are fused Triton kernels for HQQ 4-bit that use the bit-packing logic you mentioned (https://github.com/dvmazur/mixtral-offloading/blob/master/src/triton_kernels.py), but they actually run significantly slower than plain Pytorch + torch.compile on older GPUs. That's why we didn't include it as a separate backend in the project.
If you have any ideas/suggestions on how to speed up dequant or fused kernels, ideally in CUDA rather than Triton, please feel free to share any tips or contribute to the project!
@mobicham
It seems most quantization libraries adopt a specialized packing / preprocessing format in order to optimize inference speed (GPTQ / Marlin, AWQ). The primary issue is that, with quantized types, loading these weights requires special processing to effectively leverage ldmatrix and mma.sync, the primary tensorcore primitives, which typically expect fp16 / bf16 types. FasterTransformer introduced a number of innovations for mixed-type gemm targeted primarily at pre-Hopper architectures, which would also require preprocessing the weights into a special format.
I see, my CUDA knowledge is not as advanced, but I understand what you mean. It's not really an issue to add additional bit-packing methods so that the processing is aligned with the tensorcore primitives. We can have a separate backend that uses the packing format you mentioned. Do you by any chance have some examples/resources you could point me to? I guess something like this, but as I mentioned, that Triton implementation is slower than Pytorch.
@mobicham
Let me give this some further thought and get back to you. Been investigating how to generalize across different quantization types and kernels.
@jeromeku these guys use a similar packing logic (uint8 storing 2 4-bit values, with a step of 2), then they call a cutlass gemm kernel with 4-bit/4-bit-packed tensors. The paper claims a 4x speed-up; I'm wondering if the same is possible but using grouping and asymmetric quantization (the kernel seems to work with scales only): https://github.com/spcl/QuaRot/blob/main/quarot/functional/quantization.py#L42
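For comparison, that adjacent style of packing looks roughly like this (a sketch of the idea, not the exact QuaRot code):

import torch

def pack_4bit_pairs(W_q):
    # W_q: 4-bit codes in [0, 15]; consecutive pairs along the last dim share one byte.
    W_q = W_q.to(torch.uint8)
    return (W_q[..., 0::2] << 4) | W_q[..., 1::2]

def unpack_4bit_pairs(W_packed):
    out = torch.empty(*W_packed.shape[:-1], W_packed.shape[-1] * 2,
                      dtype=torch.uint8, device=W_packed.device)
    out[..., 0::2] = (W_packed & 0xF0) >> 4
    out[..., 1::2] = W_packed & 0x0F
    return out

Unpacking one byte here yields two neighboring elements, which is what makes it easier for a kernel to land contiguous values in registers.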
@mobicham
Regarding QuaRot, the cutlass kernel they're using is an IGEMM (s4 x s4 => s32), so there is no upcasting to fp16 / bf16 (which means incorporating a zero-point wouldn't work). Doing an efficient mixed gemm, where the quant types are first converted to a higher type and then scaled / shifted, would require a different gemm.
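To spell out why the zero-point is the blocker with an s4 x s4 => s32 kernel: with scale-only quantization the scales factor out of the integer matmul, whereas a zero-point adds a correction term that depends on row-sums of the activations. A toy fp32 sketch of the algebra (variable names are illustrative):

import torch

x = torch.randn(8, 64)                      # activations (kept in fp32 just to show the algebra)
Q = torch.randint(0, 16, (32, 64)).float()  # 4-bit weight codes
s = torch.rand(32, 1)                       # per-row scale
z = torch.rand(32, 1)                       # per-row zero-point

# Scale-only: W = s * Q, so the scale can be applied after the integer gemm.
out_sym = (x @ Q.t()) * s.t()

# With a zero-point: W = s * (Q - z) expands into the same integer gemm plus a
# correction term that needs the per-row sums of x; a scales-only s4 x s4 => s32
# kernel cannot express this on its own.
W = s * (Q - z)
out_ref = x @ W.t()
out_dec = (x @ Q.t()) * s.t() - x.sum(dim=1, keepdim=True) * (s * z).t()
assert torch.allclose(out_ref, out_dec, atol=1e-3)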
I am currently looking into a few things and will update:
- Fast dequant / conversion of 4b types to 16b, which will be useful for compute-bound scenarios where we can afford to separate the dequant from gemm (i.e., fine-tuning).
- Fusing dequant into the gemm mainloop, more for inference scenarios.
- Both in CUDA and triton.
Thanks @jeromeku. Also, that cutlass kernel doesn't support grouping. I am looking into ways to use Marlin, which doesn't use a zero-point either and supports no grouping other than 128, but with some math magic I was able to make it work with HQQ's dequant logic on a toy example. Now the challenge is to get high-quality quantized models with such a large group-size. I have some promising results on Llama2-7B base; let's see how that generalizes to instruct models.
That sounds great. I made some attempts to make dequant() fast by caching stuff in shared memory + manual vectorization, but it didn't speed it up. It seems like there's already enough work for the threads, so with a step=4 it's actually slower than the simple kernel:
// Attempted "step=4" variant of the 4-bit dequant kernel: each thread unpacks 4
// packed bytes, staging the packed values and the per-column zero/scale in
// shared memory first. In practice this was slower than the simple
// one-byte-per-thread kernel.
#define BLOCK_SIZE 256  // assumed threads-per-block at launch

template <typename scalar_t>
__global__ void dequantize_4bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
    const int step = 4;
    int i = (blockIdx.x*blockDim.x + threadIdx.x)*step;  // first packed element handled by this thread
    int n = h*w;                                         // number of packed elements

    const unsigned char C1 = 0xF0;  // upper nibble -> first half of the tensor
    const unsigned char C2 = 0x0F;  // lower nibble -> second half of the tensor

    // Per-thread slices of the block-level staging buffers.
    __shared__ unsigned char shared_Wq[BLOCK_SIZE*step];
    __shared__ scalar_t shared_meta[2][BLOCK_SIZE*step];
    unsigned char* Wq_s    = shared_Wq      + threadIdx.x*step;
    scalar_t*      zero_s  = shared_meta[0] + threadIdx.x*step;
    scalar_t*      scale_s = shared_meta[1] + threadIdx.x*step;

    #pragma unroll
    for (int s = 0; s < step; s++){
        if (i + s >= n) break;
        int j      = (i + s) % w;  // column index -> which zero/scale applies
        Wq_s[s]    = Wq_packed[i + s];
        zero_s[s]  = zero[j];
        scale_s[s] = scale[j];
    }
    __syncthreads();

    #pragma unroll
    for (int s = 0; s < step; s++){
        if (i + s >= n) break;
        W_r[i + s]     = (scalar_t((Wq_s[s] & C1) >> 4) - zero_s[s])*scale_s[s]; // first chunk
        W_r[i + s + n] = (scalar_t( Wq_s[s] & C2)       - zero_s[s])*scale_s[s]; // second chunk
    }
}
That's correct, for fine-tuning you either need dequant() or a transposed matmul.
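i.e., something along these lines for a quantized linear layer (a sketch with hypothetical dequantize / matmul_4bit_t helpers, not the actual hqq API):

import torch

class QuantLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, W_q_packed, scale, zero):
        ctx.save_for_backward(W_q_packed, scale, zero)
        W_r = dequantize(W_q_packed, scale, zero)  # hypothetical dequant -> fp16 weight [out, in]
        return x @ W_r.t()

    @staticmethod
    def backward(ctx, grad_out):
        W_q_packed, scale, zero = ctx.saved_tensors
        # grad wrt x needs W itself (not W.t()), so either dequantize again ...
        W_r = dequantize(W_q_packed, scale, zero)
        grad_x = grad_out @ W_r
        # ... or call a fused matmul in the transposed orientation, e.g.:
        # grad_x = matmul_4bit_t(grad_out, W_q_packed, scale, zero)  # hypothetical
        return grad_x, None, None, None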
Regarding 1-bit, someone did a nice CPU experiment; it seems the speed-up is real! https://github.com/catid/bitnet_cpu
Feel free to join us on Discord: https://discord.com/invite/VJcFz5TR we can discuss all of this!
Great project and clean implementation!
Wondering what the motivation is for packing 4-bit tensors (into u8) such that the first half of the tensor is interleaved with the second half. More specifically, each byte contains a 4-bit val from the first half of the tensor (in the upper bits) and a 4-bit val from the second half of the tensor (in the lower bits), as opposed to packing consecutive values, such that unpacking could more easily yield contiguous segments of the original tensor.