mit-han-lab / llm-awq

[MLSys 2024 Best Paper Award] AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration

AWQ kernel · Issue #189

Open KThyo opened 4 months ago

KThyo commented 4 months ago

While conducting AWQ-related research, I discovered a problem with the new AWQ GEMM kernel.

When I run this code:

from typing import Tuple
import random
import time

import torch

import awq_inference_engine

@torch.inference_mode()
def main() -> None:
    # Random stand-in inputs; the CUDA kernel needs CUDA tensors
    # (M=512, K=1024, N=4096, with 32 scale/zero groups along K).
    reshaped_x = torch.rand(512, 1024, dtype=torch.half, device="cuda")
    qweight = torch.randint(20000, (1024, 1024), dtype=torch.int16, device="cuda")
    scales = torch.rand(32, 4096, dtype=torch.half, device="cuda")
    qzeros = torch.rand(32, 4096, dtype=torch.half, device="cuda")
    print("reshaped_x", reshaped_x)
    print("qweight", qweight)
    print("scales", scales)
    print("qzeros", qzeros)

    def run_cuda_benchmark(num_iters: int) -> Tuple[torch.Tensor, float]:
        start_time = time.perf_counter()

        for _ in range(num_iters):
            out = awq_inference_engine.gemm_forward_cuda_new(reshaped_x,
                qweight,
                scales,
                qzeros
            )
            print("out", out)
            print("out.shape", out.shape)
            print("--------------------------------------------------------")

        end_time = time.perf_counter()
        return out, (end_time - start_time) / num_iters

    out, latency = run_cuda_benchmark(num_iters=5)

    print(f"Kernel running time: {latency * 1000000:.3f} us")

if __name__ == '__main__':
    main()
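
For cross-checking, here is a pure-PyTorch reference of what I would expect a grouped 4-bit dequantize-then-GEMM over these shapes to compute. The packing layout is my assumption (four 4-bit values per int16 along the output dimension, lowest nibble first, with qzeros holding pre-scaled fp16 zeros); the actual layout used by gemm_forward_cuda_new is probably interleaved differently, so treat this as a sketch for sanity-checking magnitudes, not a bit-exact oracle:

@torch.inference_mode()
def reference_gemm(x: torch.Tensor, qweight: torch.Tensor,
                   scales: torch.Tensor, qzeros: torch.Tensor) -> torch.Tensor:
    K, packed_n = qweight.shape
    group_size = K // scales.shape[0]  # 1024 // 32 == 32 for the shapes above
    shifts = torch.arange(4, device=x.device) * 4
    # Unpack four 4-bit values per int16: (K, N/4) -> (K, N), lowest nibble first.
    w = (qweight.unsqueeze(-1).to(torch.int32) >> shifts) & 0xF
    w = w.reshape(K, packed_n * 4).to(torch.half)
    # Broadcast the per-group scales/zeros over the rows of each group.
    s = scales.repeat_interleave(group_size, dim=0)
    z = qzeros.repeat_interleave(group_size, dim=0)
    # Assumes qzeros holds pre-scaled zero points (its fp16 dtype suggests so).
    return x @ (w * s - z)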

The reshaped_x/qweight/scales/qzeros tensors hold exactly the same values on every iteration, yet the output tensor is mostly zeros and its nonzero entries change from one iteration to the next, as the log below shows.
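
If the four inputs really are unchanged between calls, two back-to-back calls should produce bit-identical outputs. A minimal sanity check (the assertions, the clone() calls, and the synchronize are my additions, not part of awq_inference_engine) could be:

# Rule out device/layout problems explicitly; the nondeterministic output
# looks like the kernel reading memory it should not.
for t in (reshaped_x, qweight, scales, qzeros):
    assert t.is_cuda and t.is_contiguous()

out1 = awq_inference_engine.gemm_forward_cuda_new(
    reshaped_x.clone(), qweight.clone(), scales.clone(), qzeros.clone())
out2 = awq_inference_engine.gemm_forward_cuda_new(
    reshaped_x.clone(), qweight.clone(), scales.clone(), qzeros.clone())
torch.cuda.synchronize()
print("deterministic:", torch.equal(out1, out2))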

reshaped_x tensor([[0.0820, 0.4624, 0.9722,  ..., 0.1504, 0.9507, 0.2251],
        [0.6201, 0.7651, 0.4380,  ..., 0.8745, 0.8291, 0.9238],
        [0.9556, 0.1509, 0.7280,  ..., 0.4575, 0.9678, 0.6660],
        ...,
        [0.6904, 0.5151, 0.0771,  ..., 0.0024, 0.3013, 0.4009],
        [0.2988, 0.4912, 0.3652,  ..., 0.4590, 0.6748, 0.0239],
        [0.5913, 0.6523, 0.7417,  ..., 0.4956, 0.8652, 0.6777]],
       dtype=torch.float16)
qweight tensor([[11579, 18300, 12826,  ..., 10524,  4347, 14946],
        [12009, 18653,  3196,  ..., 11657, 10654, 11258],
        [12776,  5834,  1425,  ..., 19792, 18499,  6622],
        ...,
        [10571, 18958,  5022,  ..., 10509,  9912, 10937],
        [16446,  3157,  2399,  ...,  2216,  4870, 17443],
        [ 9492, 15657,  3169,  ...,  8797, 10379, 16946]], dtype=torch.int16)
scales tensor([[0.2964, 0.2588, 0.4531,  ..., 0.1895, 0.4629, 0.5820],
        [0.5103, 0.0381, 0.6602,  ..., 0.1362, 0.9180, 0.0073],
        [0.1587, 0.5503, 0.1836,  ..., 0.2563, 0.1318, 0.2974],
        ...,
        [0.3149, 0.1064, 0.0391,  ..., 0.0156, 0.7944, 0.6191],
        [0.2319, 0.6514, 0.1348,  ..., 0.4155, 0.5620, 0.3589],
        [0.1416, 0.5166, 0.9712,  ..., 0.1255, 0.8154, 0.7256]],
       dtype=torch.float16)
qzeros tensor([[0.2437, 0.3042, 0.8623,  ..., 0.8984, 0.6992, 0.2129],
        [0.9609, 0.0703, 0.0864,  ..., 0.2197, 0.9585, 0.5664],
        [0.9497, 0.1504, 0.6274,  ..., 0.4927, 0.8916, 0.7173],
        ...,
        [0.3921, 0.6421, 0.9507,  ..., 0.4248, 0.4219, 0.2622],
        [0.6021, 0.9248, 0.4521,  ..., 0.2451, 0.5024, 0.8091],
        [0.3140, 0.4023, 0.8833,  ..., 0.3560, 0.8906, 0.0376]],
       dtype=torch.float16)
--------------------------------------------------------
out tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16)
out.shape torch.Size([512, 4096])
--------------------------------------------------------
out tensor([[-6.0638e-02,  3.8225e+02,  2.9802e-07,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=torch.float16)
out.shape torch.Size([512, 4096])
--------------------------------------------------------
out tensor([[ 0.0048, -0.1835,     nan,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       dtype=torch.float16)
out.shape torch.Size([512, 4096])
--------------------------------------------------------
out tensor([[-8.7500e-01, -2.7847e-03,  9.3938e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=torch.float16)
out.shape torch.Size([512, 4096])
--------------------------------------------------------
out tensor([[ 7.8125e-01, -3.5324e-03,  9.3938e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=torch.float16)
out.shape torch.Size([512, 4096])
--------------------------------------------------------
Kernel running time: 52558.644 us
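
One note on the number above: the timed loop includes the print calls, and without a torch.cuda.synchronize() the perf_counter delta may only measure kernel launches rather than execution. A more reliable timing sketch using CUDA events (my own code, not part of the repo's benchmarks) would be:

num_iters = 5
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Warm-up launch so one-time initialization is not measured.
awq_inference_engine.gemm_forward_cuda_new(reshaped_x, qweight, scales, qzeros)
torch.cuda.synchronize()

start.record()
for _ in range(num_iters):
    awq_inference_engine.gemm_forward_cuda_new(reshaped_x, qweight, scales, qzeros)
end.record()
torch.cuda.synchronize()
# elapsed_time() returns milliseconds; convert to microseconds per iteration.
print(f"Kernel running time: {start.elapsed_time(end) / num_iters * 1000:.3f} us")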

Where is the problem?

Thank you.