microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License

NF4: Compilation Errors #37

Closed. HanGuo97 closed this issue 1 month ago

HanGuo97 commented 1 month ago

Hi, first of all, thanks for this amazing repo!

I noticed that I repeatedly hit the following compilation error (this was for NF4, but the same thing happened for other data types as well). Any thoughts on how to fix it?

error: too few arguments in function call
          main_kernel<<<dim3(1024, 1, 1), dim3(16, 8, 1), 0, 0>>>(A, B, D, LUT, Scale, Zeros);

Thanks in advance for your time!

LeiWang1999 commented 1 month ago

Hi @HanGuo97, thanks for your attention! Would you mind providing a script to reproduce this?

HanGuo97 commented 1 month ago

Thanks for the response --- this is the script I used:


import torch
from typing import Dict
import bitblas

def prepare_bitblas_data(
    m: int,
    n: int,
    k: int,
    W_dtype: str,
    group_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:
    if dtype == torch.float16:
        A_dtype = "float16"
    else:
        raise ValueError
    if W_dtype not in ["uint4", "uint2", "nf4", "int4"]:
        raise ValueError

    matmul_config = bitblas.MatmulConfig(
        M=m,
        N=n,
        K=k,
        A_dtype=A_dtype,
        W_dtype=W_dtype,
        accum_dtype=A_dtype,
        out_dtype=A_dtype,
        layout="nt",
        with_bias=False,
        # configs for weight only quantization
        group_size=group_size,
        with_scaling=True,
        with_zeros=True,
        zeros_mode="original")
    matmul = bitblas.Matmul(config=matmul_config)

    g = int(k / group_size)

    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)

    B = torch.randint(
        0, 7,
        (n, k),
        dtype=torch.int8,
        device=device)

    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)

    S = torch.ones(
        (n, g),
        dtype=dtype,
        device=device)

    Q = matmul.transform_weight(B)

    torch.cuda.synchronize()
    return {
        "A": A,
        "Q": Q,
        "D": D,
        "S": S,
        "fn": matmul,
    }

data = prepare_bitblas_data(
    8, 8192, 8192,
    W_dtype="nf4",
    group_size=128,
    dtype=torch.float16,
    device="cuda")

LeiWang1999 commented 1 month ago

@HanGuo97 This is a very interesting bug with the static shape path (there's a bug for static 1 < M < 16; let me take a look tomorrow).

In the meantime, a quick fix is to generate a dynamic-M kernel (in this case, M=8 will be padded up to the M=16 tensor core shape):

import torch
from typing import Dict
import bitblas
bitblas.set_log_level("DEBUG")

def prepare_bitblas_data(
    m: int,
    n: int,
    k: int,
    W_dtype: str,
    group_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:
    if dtype == torch.float16:
        A_dtype = "float16"
    else:
        raise ValueError
    if W_dtype not in ["uint4", "uint2", "nf4", "int4"]:
        raise ValueError

    matmul_config = bitblas.MatmulConfig(
        N=n,
        K=k,
        A_dtype=A_dtype,
        W_dtype=W_dtype,
        accum_dtype=A_dtype,
        out_dtype=A_dtype,
        layout="nt",
        with_bias=False,
        # configs for weight only quantization
        group_size=group_size,
        with_scaling=True,
        with_zeros=True,
        zeros_mode="original")
    matmul = bitblas.Matmul(config=matmul_config)

    g = int(k / group_size)

    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)

    B = torch.randint(
        0, 7,
        (n, k),
        dtype=torch.int8,
        device=device)

    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)

    S = torch.ones(
        (n, g),
        dtype=dtype,
        device=device)

    Q = matmul.transform_weight(B)

    torch.cuda.synchronize()
    return {
        "A": A,
        "Q": Q,
        "D": D,
        "S": S,
        "fn": matmul,
    }

data = prepare_bitblas_data(
    8, 8192, 8192,
    W_dtype="uint4",
    group_size=128,
    dtype=torch.float16,
    device="cuda")

LeiWang1999 commented 1 month ago

I've fixed the static shape case as well; the upstream code should work now :) @HanGuo97

HanGuo97 commented 1 month ago

Thanks for the quick fix @LeiWang1999 !

When you say "upstream", do you mean the latest pip wheel released yesterday? I think I'm still bumping into the same problem with it. Unfortunately, I'm not familiar with building from scratch, if that's what you meant (it looks like I would need to build some of the TVM dependencies as well).

On an unrelated note, for the NF4 data type there is another (small) problem:

    337 if source_format == "nf":
--> 338     self.lut = torch.Tensor(([
    339         -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    340         -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    341         0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    342         0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
    343     ]),
    344                             dtype=getattr(torch, self.A_dtype)).cuda()
    345 else:
    346     self.lut = None

TypeError: new() received an invalid combination of arguments - got (list, dtype=torch.dtype), but expected one of:
 * (*, torch.device device)
      didn't match because some of the keywords were incorrect: dtype
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
 * (object data, *, torch.device device)

This is pretty easily fixable, but I just wanted to point it out in case you missed it.
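
For what it's worth, a minimal sketch of the kind of fix I mean (assuming self.lut is intended to be a CUDA tensor of A_dtype): torch.tensor accepts a dtype keyword, whereas the torch.Tensor constructor does not:

# Possible fix (sketch): build the LUT with torch.tensor, which accepts dtype=...,
# instead of the torch.Tensor constructor.
self.lut = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
], dtype=getattr(torch, self.A_dtype)).cuda()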

LeiWang1999 commented 1 month ago

@HanGuo97 Thanks. I've released a new version on PyPI that resolves the tensor allocation issue. Please check it out.

pip install bitblas==0.0.1.dev5

The correctness of NF4 has been assessed:

import torch
import bitblas
bitblas.set_log_level("DEBUG")

M = 16
N = 1024
K = 1024
layout = "nt"
matmul_config = bitblas.MatmulConfig(
    M=M,
    N=N,
    K=K,
    A_dtype="float16",
    W_dtype="nf4",
    accum_dtype="float16",
    out_dtype="float16",
    layout=layout,
)
matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
output_shape = (M, N)

input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 15, weight_shape, dtype=torch.int8).cuda()
bitblas_weight_tensor = matmul.transform_weight(weight_tensor)

bitblas_output_tensor = matmul(input_tensor, bitblas_weight_tensor)
print("BitBLAS output:", bitblas_output_tensor)

ref_weight_tensor = torch.empty_like(weight_tensor, dtype=torch.float16)

for i in range(N):
    for j in range(K):
        ref_weight_tensor[i, j] = matmul.lut[weight_tensor[i, j].item()]

ref_output_tensor = torch.matmul(input_tensor, ref_weight_tensor.t())
print("Ref output:", ref_output_tensor)
'''
BitBLAS output: tensor([[-29.9531, -19.3750, -19.3438,  ..., -12.8594, -18.5000, -13.9219],
        [-26.3438, -16.2344, -22.2656,  ..., -12.9844, -22.2344, -13.8516],
        [-27.7344, -14.4375, -19.3281,  ..., -15.1797, -18.0469,  -6.1758],
        ...,
        [-17.9531, -14.4688, -21.1562,  ..., -13.8828, -28.7188,  -5.4453],
        [-17.0312, -12.0469, -15.7344,  ...,  -9.2656, -19.2656,  -9.5781],
        [-21.9688, -18.0938, -16.7656,  ..., -12.5781,  -9.9688,  -5.8359]],
       device='cuda:0', dtype=torch.float16)
Ref output: tensor([[-29.8594, -19.3438, -19.4062,  ..., -12.8438, -18.5469, -13.9219],
        [-26.4062, -16.1719, -22.3438,  ..., -12.9844, -22.3125, -13.9531],
        [-27.8125, -14.3828, -19.3594,  ..., -15.1328, -18.0312,  -6.1367],
        ...,
        [-17.9688, -14.5703, -21.3438,  ..., -13.9219, -28.6875,  -5.4727],
        [-16.9219, -12.0703, -15.7266,  ...,  -9.2344, -19.4062,  -9.5781],
        [-22.0625, -18.0781, -16.7500,  ..., -12.5469,  -9.9375,  -5.8398]],
       device='cuda:0', dtype=torch.float16)
'''
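
As a side note, the Python double loop in the reference check can be slow for large N and K; a vectorized equivalent (a sketch, assuming matmul.lut is the 16-entry float16 CUDA LUT shown above) is:

# Gather the dequantized reference weights in one indexing op: (N, K) lookup into the 16-entry LUT.
ref_weight_tensor = matmul.lut[weight_tensor.long()]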

BTW, if you build from scratch, running pip install . in the root directory should, in the ideal case, build TVM automatically :).

HanGuo97 commented 1 month ago

Amazing, thanks for the prompt response!

This version works now.

HanGuo97 commented 1 month ago

Hmm, for some reason the NF4 example triggers something that gets the Python process killed.

EDIT: nvm, that was on me. Closing.