hi @HanGuo97, thanks for your attention. Would you mind providing the reproduction script?
Thanks for the response --- this is the script I used:
import torch
from typing import Dict

import bitblas


def prepare_bitblas_data(
    m: int,
    n: int,
    k: int,
    W_dtype: str,
    group_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:
    if dtype == torch.float16:
        A_dtype = "float16"
    else:
        raise ValueError
    if W_dtype not in ["uint4", "uint2", "nf4", "int4"]:
        raise ValueError

    matmul_config = bitblas.MatmulConfig(
        M=m,
        N=n,
        K=k,
        A_dtype=A_dtype,
        W_dtype=W_dtype,
        accum_dtype=A_dtype,
        out_dtype=A_dtype,
        layout="nt",
        with_bias=False,
        # configs for weight only quantization
        group_size=group_size,
        with_scaling=True,
        with_zeros=True,
        zeros_mode="original")
    matmul = bitblas.Matmul(config=matmul_config)

    g = int(k / group_size)
    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)
    # raw integer weight codes; packed into the BitBLAS layout below
    B = torch.randint(
        0, 7,
        (n, k),
        dtype=torch.int8,
        device=device)
    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)
    S = torch.ones(
        (n, g),
        dtype=dtype,
        device=device)
    Q = matmul.transform_weight(B)
    torch.cuda.synchronize()
    return {
        "A": A,
        "Q": Q,
        "D": D,
        "S": S,
        "fn": matmul,
    }


data = prepare_bitblas_data(
    8, 8192, 8192,
    W_dtype="nf4",
    group_size=128,
    dtype=torch.float16,
    device="cuda")
@HanGuo97 It's a very interesting bug with a static shape (there's a bug for static 1 < m < 16; let me take a look tomorrow).
A quick fix in the meantime is to generate a dynamic-m kernel (in this case, m=8 will be padded up to the m=16 tensor core shape):
import torch
from typing import Dict

import bitblas

bitblas.set_log_level("DEBUG")


def prepare_bitblas_data(
    m: int,
    n: int,
    k: int,
    W_dtype: str,
    group_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:
    if dtype == torch.float16:
        A_dtype = "float16"
    else:
        raise ValueError
    if W_dtype not in ["uint4", "uint2", "nf4", "int4"]:
        raise ValueError

    # M is intentionally omitted so the kernel is compiled for a dynamic m
    # (m=8 then gets padded up to the m=16 tensor core shape at runtime).
    matmul_config = bitblas.MatmulConfig(
        N=n,
        K=k,
        A_dtype=A_dtype,
        W_dtype=W_dtype,
        accum_dtype=A_dtype,
        out_dtype=A_dtype,
        layout="nt",
        with_bias=False,
        # configs for weight only quantization
        group_size=group_size,
        with_scaling=True,
        with_zeros=True,
        zeros_mode="original")
    matmul = bitblas.Matmul(config=matmul_config)

    g = int(k / group_size)
    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)
    B = torch.randint(
        0, 7,
        (n, k),
        dtype=torch.int8,
        device=device)
    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)
    S = torch.ones(
        (n, g),
        dtype=dtype,
        device=device)
    Q = matmul.transform_weight(B)
    torch.cuda.synchronize()
    return {
        "A": A,
        "Q": Q,
        "D": D,
        "S": S,
        "fn": matmul,
    }


data = prepare_bitblas_data(
    8, 8192, 8192,
    W_dtype="uint4",
    group_size=128,
    dtype=torch.float16,
    device="cuda")
I've fixed the static shape case as well; the upstream code should work :) @HanGuo97
Thanks for the quick fix @LeiWang1999 !
When you say "upstream", do you mean the latest pip wheel released yesterday? I think I'm bumping into the same problem. Unfortunately, I'm not familiar with building from scratch if that's what you meant (it looks like I need to build some of the TVM dependencies as well).
On an unrelated note. For NF4 data type, there is another (small) problem:
    337     if source_format == "nf":
--> 338         self.lut = torch.Tensor(([
    339             -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    340             -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    341             0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    342             0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
    343         ]),
    344             dtype=getattr(torch, self.A_dtype)).cuda()
    345     else:
    346         self.lut = None
TypeError: new() received an invalid combination of arguments - got (list, dtype=torch.dtype), but expected one of:
* (*, torch.device device)
didn't match because some of the keywords were incorrect: dtype
* (torch.Storage storage)
* (Tensor other)
* (tuple of ints size, *, torch.device device)
* (object data, *, torch.device device)
This is pretty easily fixable, but I just wanted to point it out in case you missed it.
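For reference, a minimal sketch of one possible fix (not necessarily what the library ends up doing): torch.tensor (lowercase) accepts dtype= and device= keywords, unlike the torch.Tensor constructor, so the LUT could be allocated like this:

import torch

A_dtype = "float16"  # stand-in for self.A_dtype from the snippet above

# torch.tensor accepts dtype=/device=; the torch.Tensor constructor does not,
# which is what raises the TypeError shown above.
nf4_lut = torch.tensor(
    [
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
        0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
    ],
    dtype=getattr(torch, A_dtype),
    device="cuda")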
@HanGuo97 Thanks. I've released a new version on PyPI that resolves the tensor allocation issue. Please check it out.
pip install bitblas==0.0.1.dev5
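(A quick way to confirm which wheel is actually active, using only the standard library:)

# Check the installed BitBLAS wheel version via importlib.metadata.
from importlib.metadata import version
print(version("bitblas"))  # expected: 0.0.1.dev5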
The correctness of NF4 has been assessed:
import torch

import bitblas

bitblas.set_log_level("DEBUG")

M = 16
N = 1024
K = 1024
layout = "nt"

matmul_config = bitblas.MatmulConfig(
    M=M,
    N=N,
    K=K,
    A_dtype="float16",
    W_dtype="nf4",
    accum_dtype="float16",
    out_dtype="float16",
    layout=layout,
)
matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
output_shape = (M, N)

input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 15, weight_shape, dtype=torch.int8).cuda()

bitblas_weight_tensor = matmul.transform_weight(weight_tensor)
bitblas_output_tensor = matmul(input_tensor, bitblas_weight_tensor)
print("BitBLAS output:", bitblas_output_tensor)

# Reference: dequantize through the NF4 lookup table, then do a plain fp16 matmul.
ref_weight_tensor = torch.empty_like(weight_tensor, dtype=torch.float16)
for i in range(N):
    for j in range(K):
        ref_weight_tensor[i, j] = matmul.lut[weight_tensor[i, j].item()]
ref_output_tensor = torch.matmul(input_tensor, ref_weight_tensor.t())
print("Ref output:", ref_output_tensor)
'''
BitBLAS output: tensor([[-29.9531, -19.3750, -19.3438, ..., -12.8594, -18.5000, -13.9219],
[-26.3438, -16.2344, -22.2656, ..., -12.9844, -22.2344, -13.8516],
[-27.7344, -14.4375, -19.3281, ..., -15.1797, -18.0469, -6.1758],
...,
[-17.9531, -14.4688, -21.1562, ..., -13.8828, -28.7188, -5.4453],
[-17.0312, -12.0469, -15.7344, ..., -9.2656, -19.2656, -9.5781],
[-21.9688, -18.0938, -16.7656, ..., -12.5781, -9.9688, -5.8359]],
device='cuda:0', dtype=torch.float16)
Ref output: tensor([[-29.8594, -19.3438, -19.4062, ..., -12.8438, -18.5469, -13.9219],
[-26.4062, -16.1719, -22.3438, ..., -12.9844, -22.3125, -13.9531],
[-27.8125, -14.3828, -19.3594, ..., -15.1328, -18.0312, -6.1367],
...,
[-17.9688, -14.5703, -21.3438, ..., -13.9219, -28.6875, -5.4727],
[-16.9219, -12.0703, -15.7266, ..., -9.2344, -19.4062, -9.5781],
[-22.0625, -18.0781, -16.7500, ..., -12.5469, -9.9375, -5.8398]],
device='cuda:0', dtype=torch.float16)
'''
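As a side note, the reference loop above is quadratic in Python; a vectorized equivalent of the same LUT dequantization, plus an explicit tolerance check, might look like this (the tolerances are just a rough choice for fp16):

# Vectorized equivalent of the double loop: index the 16-entry NF4 LUT
# with the integer weight codes, then compare against the BitBLAS output.
ref_weight_tensor = matmul.lut[weight_tensor.long()]
ref_output_tensor = torch.matmul(input_tensor, ref_weight_tensor.t())

# fp16 accumulation differs slightly between the two paths, so use loose tolerances.
torch.testing.assert_close(
    bitblas_output_tensor, ref_output_tensor, rtol=1e-2, atol=1e-1)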
BTW, if you build from scratch, running the command pip install .
in the root directory should build TVM automatically in the ideal scenario :).
Amazing, thanks for the prompt response!
This version works now.
Hmm, for some reason the NF4 example triggers something that gets the Python process killed.
EDIT: nvm, that was on me. Closing.
Hi, first of all, thanks for this amazing repo!
I noticed that I repeatedly hit the following compilation error (this was for NF4, but the same thing happened for other data types as well). Any thoughts on how to fix it?
Thanks in advance for your time!