PaddlePaddle / Paddle

PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (『飞桨』核心框架,深度学习&机器学习高性能单机、分布式训练和跨平台部署)
http://www.paddlepaddle.org/
Apache License 2.0
21.66k stars 5.44k forks source link

[PHI] add int4 weight only quant kernel, add int4 weight only permute kernel #64094

Open yinfan98 opened 1 week ago

yinfan98 commented 1 week ago

PR Category

Others

PR Types

New features

Description

给paddle添加int4量化的kernel和int4量化进行permute的kernel。

TL;DR

支持了一个GPU kernel,它能做int4 weight only量化的工作。并且能支持weight_only_linear (同时也能和反量化接口对齐,如果你想单纯做量化反量化看看。你可以这么执行代码)

import paddle
x = paddle.randn(shape=[4096, 2048], dtype=paddle.float16)
qt, scale = paddle.nn.quant.weight_quantize(x, algo='weight_only_int4')
## 啊 paddle暂时还不可以形状推导。 但是PR已经在合了
## view之前的shape应该是[1024, 4096],这个shape是做weight only linear用的。后续也可以加一个接口判断是否矩阵乘法来判断是否在c++侧reshape
qt = qt.view([2048, 2048])
x_dq = paddle.nn.quant.weight_dequantize(qt, scale, algo='weight_only_int4')

当然,weight only linear也是支持的

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize, weight_dequantize

x = paddle.rand(shape=(2, 4096), dtype='float16')

weight = paddle.randn(shape=(4096, 2048), dtype='float32')
weight = weight.astype('float16')

quant_weight, quant_scale = weight_quantize(x=weight, algo='weight_only_int4')
quant_out = weight_only_linear(x=quant_x, weight=quant_weight, weight_scale=quant_scale, weight_dtype="int4")
## 能和它大概对齐吧,毕竟int4量化的精度低的离谱 out = paddle.matmul(x=x, y=weight)

int4 weight only quant总结

参考CPU的实现,SM70以上kernel的实现分几个步骤:

  1. 按行进行pack(2int4pack成一个int8)
  2. permute_B_rows_for_mixed_gemm:排布列方向的元素
  3. subbyte_transpose:把列主序的weight变成行主序的,并且由按行进行pack转化成按列进行pack。
  4. interleave_column_major_tensor:每64个元素进行interleave
  5. add_bias_and_interleave_int4s_inplace:把int8转换成uint8(+8)

但是我们其实不需要这么复杂的实现,我们可以直接就按列进行pack。也能达到一样的效果。并且只需要两个kernel(加上量化需要三个kernel)。方法如下:

int4量化kernel

对于int4量化来说,我们分别实现了按行pack和按列pack。(为了让SM70版本的显卡也能正常工作QAQ) 对按列pack来说,它需要让两个int4pack成一个int8的数进行实现。在代码里,我们让上下两行组成一个int8的数,也就是按列进行的pack。

int4 permute kernel

对于int4量化,我们需要对输入数据进行重排来适配cutlass的快速反量化kernel。 在int4反量化端,我们观察反量化算子实现可以发现。最后所需的输出是:

0   1   8   9  16  17  24  25   2   3  10  11  18  19  26  27
4   5  12  13  20  21  28  29   6   7  14  15  22  23  30  31

参考cutlass的快速反量化实现。 int4快速反量化4个int8一组,能把int8的数据转换为fp16的。但它会改变数据的排布:

0 2 4 6 1 3 5 7 -> 0 1 2 3 4 5 6 7

则我们可以推得在快速反量化之前,我们需要的数据是

//  0   8  16  24   1   9  17  25   2  10  18  26   3  11  19  27
//  4  12  20  28   5  13  21  29   6  14  22  30   7  15  23  31

上面一组数看上去没有任何的规律,但是我们可以给它做一点小小的调整,调整成下面的形式,只需要一些简单的位运算即可

// 0 1 16 17 8 9 24 25 2 3 18 19 10 11 26 27
// 4 5 20 21 12 13 28 29 6 7 22 23 14 15 30 31

我们知道,两个int4 pack成了一个int8,我们也可以把上面的数调整成int8的index

0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15

那么从

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> 0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15

的坐标为

0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15

得到这个新的permute_kk(代码里的变量,描述列之间的permute),可以通过int8的permute_kk做一点小小的改变 从int8 permute转换为int4 permute int8

0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15

可以把它变成

0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15

% 8 * 2

0 4 8 12 2 6 10 14 0 4 8 12 2 6 10 14
add 1 for 0 4 8 12 2 6 10 14 [0 4 8 12 2 6 10 14]

简单的位运算kernel(最后执行)

// (0 1) (16 17) (8 9) (24 25) (2 3) (18 19) (10 11) (26 27)
// (4 5) (20 21) (12 13) (28 29) (6 7) (22 23) (14 15) (30 31)

//  0   8  16  24   1   9  17  25   2  10  18  26   3  11  19  27
//  4  12  20  28   5  13  21  29   6  14  22  30   7  15  23  31

我们可以每四个数一组,然后02 13 之间做低四位和高四位的交换即可。

int4 row interleave

对于int8的case,代码在相邻的两行中,每64个元素进行交织。但是对于int4的情况。代码就会在相邻的四行中,每32个元素进行交织。所以在permute的处理时,写成了

int permute_index = permute_kk % 32 + permute_kk / 32 * 128 +
                        32 * (n_id % 4) + total_k * 4 * (n_id / 4);

这样也符合预期。(写着写着天都亮了zzz)

paddle-bot[bot] commented 1 week ago

你的PR提交成功,感谢你对开源项目的贡献! 请关注后续CI自动化测试结果,详情请参考Paddle-CI手册。 Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle-ci-bot[bot] commented 15 hours ago

Sorry to inform you that 19619b4's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.