HandH1998 / QQQ

QQQ is an innovative and hardware-optimized W4A8 quantization solution.
https://arxiv.org/pdf/2406.09904

Possibility of using different group size setting #9

Open NicoNico6 opened 1 month ago

NicoNico6 commented 1 month ago

Hi, thanks for your great work and for open-sourcing it. I am trying different quantization group sizes (from 128 down to 64/32) by changing the default hyperparameter `group_size`, but the GEMM results are NaN, unlike with group size 128.

Could you share some ideas on this issue? How can I support other group sizes?

HandH1998 commented 3 weeks ago

@NicoNico6 The w4a8 GEMM kernel only supports `group_size=128` for now. If you want to support other group sizes, you need to design new GEMM kernels, which may be a little complicated.

HandH1998 commented 3 weeks ago

@NicoNico6 The kernel requires `group_size % thread_k == 0`. As `thread_k` is either 128 or 64 (refer to the following code), `group_size` should be a multiple of 128, such as 256.

```cpp
if (thread_k == -1 || thread_n == -1) {
  if (prob_m <= 16) {
    // For small batch sizes, better partitioning is slightly more important
    // than better compute utilization
    thread_k = 128;
    thread_n = 128;
  } else {
    thread_k = 64;
    thread_n = 256;
  }
}
```
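
To make the constraint concrete, here is a minimal standalone sketch (hypothetical, not from the QQQ/marlin sources) that checks a candidate `group_size` against both `thread_k` values the dispatch logic above can choose. Only multiples of 128 satisfy `group_size % thread_k == 0` in every case.

```cpp
// Hypothetical helper, not part of QQQ/marlin: a group_size is usable only if
// it is divisible by every thread_k the dispatch logic may select
// (128 for small batches, 64 otherwise).
#include <cstdio>

bool group_size_supported(int group_size) {
  for (int thread_k : {128, 64}) {
    if (group_size % thread_k != 0) return false;  // violates group_size % thread_k == 0
  }
  return true;
}

int main() {
  for (int gs : {32, 64, 128, 256}) {
    std::printf("group_size=%d -> %s\n", gs,
                group_size_supported(gs) ? "OK (multiple of 128)" : "unsupported");
  }
  return 0;
}
```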
darrenearl commented 2 weeks ago

I added the following code to the marlin kernel:

```cpp
CALL_IF(1, 8, 8, 16)
CALL_IF(1, 16, 4, 16)
CALL_IF(2, 16, 4, 16)
CALL_IF(3, 16, 4, 16)
CALL_IF(4, 16, 4, 16)
```

Then I ran test_w4a8.py with groupsize=256, but it still fails:

```
FAIL: test_groups (__main__.Test)

Traceback (most recent call last):
  File "/workspace/marlin/test_w4a8.py", line 178, in test_groups
    self.run_problem(m, n, k, *thread_shape, groupsize)
  File "/workspace/marlin/test_w4a8.py", line 87, in run_problem
    self.assertLess(torch.mean(torch.abs(D - D_ref)) / torch.mean(torch.abs(D_ref)), 0.003)
AssertionError: tensor(1.4229, device='cuda:0', dtype=torch.float16) not less than 0.003

----------------------------------------------------------------------
Ran 6 tests in 53.540s

FAILED (failures=1)
```

What else in the code needs to be modified?