microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.

Matrix multiplication is outputting unexpected values if W is transposed through PyTorch's `t()` function #33

Closed rokada-br closed 1 month ago

rokada-br commented 1 month ago

Greetings!

I've been trying to multiply small matrices $A \times W$ to learn how to use BitBLAS properly. From my understanding, layout="nt" means that W should be passed in transposed.

So far, initializing the weights with the values already transposed ($W^T$) gives the correct result. However, if I initialize $W$ and then transpose the tensor with W.t() or torch.transpose(W, 0, 1), the output is no longer correct.
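
In plain PyTorch terms, I believe layout="nt" corresponds to the reference computation below (the shapes and the C = A @ W.T convention are just my reading of it, so please correct me if that's wrong):

import torch

# My assumption: A is (M, K), W is stored as (N, K) ("already transposed"),
# and the kernel computes C = A @ W.T, giving an (M, N) result.
A = torch.tensor([[2, 3]], dtype=torch.int32)                     # (1, 2)
W_nt = torch.tensor([[4, 2], [2, 1], [3, 2]], dtype=torch.int32)  # (3, 2)
print(A @ W_nt.t())
tensor([[14,  7, 12]], dtype=torch.int32)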

config = bitblas.MatmulConfig(layout="nt", ...)
matmul = bitblas.Matmul(config) # int8 input/output, with int32 accumulation

a = torch.Tensor(...).cuda()
w = torch.Tensor(...).cuda()  # W 
wt = torch.Tensor(...).cuda()  # W with values already transposed

# This prints the correct answer
c = matmul(matmul.transform_input(a.to(torch.int8)), 
           matmul.transform_weight(wt.to(torch.int8)))
print(c)

# This prints a different answer
c = matmul(matmul.transform_input(a.to(torch.int8)), 
           matmul.transform_weight(w.t().to(torch.int8)))
print(c)

Is my understanding of how I should be using the library correct?

System Specs

Code Sample

Matrix multiplication with int8 values and int32 accumulation.

$$ \begin{bmatrix} 2 & 3 \end{bmatrix} \begin{bmatrix} 4 & 2 & 3 \\ 2 & 1 & 2 \end{bmatrix} = \begin{bmatrix} 14 & 7 & 12 \end{bmatrix} $$

import bitblas
import torch

matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=3,  # N dimension
    K=2,  # K dimension
    A_dtype="int8",      # activation A dtype
    W_dtype="int8",      # weight W dtype
    accum_dtype="int32", # accumulation dtype
    out_dtype="int8",    # output dtype
    layout="nt",         # matrix layout; "nt" means A is non-transposed and W is transposed
    with_bias=False,     # bias
    # configs for weight only quantization
    group_size=None,     # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,    # setting for zeros
    zeros_mode=None,     # setting for how to calculate zeros
)

matmul = bitblas.Matmul(config=matmul_config)

PyTorch's matmul gives the expected answer:

a = torch.Tensor([[2, 3]]).cuda()
wt = torch.Tensor([[4, 2], [2, 1], [3, 2]]).cuda()
print(torch.matmul(a, wt.t()))
tensor([[14.,  7., 12.]], device='cuda:0')

Likewise, using BitBLAS with int8 gives the correct answer:

c = matmul(matmul.transform_input(a.to(torch.int8)), 
           matmul.transform_weight(wt.to(torch.int8)))
print(c)
tensor([[14,  7, 12]], device='cuda:0', dtype=torch.int8)

However, if I initialize w with the values in their "natural" order and transpose it afterwards, the output is no longer the same:

w = torch.Tensor([[4, 2, 3], [2, 1, 2]]).cuda()
print(torch.matmul(a, w))

c = matmul(matmul.transform_input(a.to(torch.int8)), 
           matmul.transform_weight(w.t().to(torch.int8)))
print(c)
tensor([[14.,  7., 12.]], device='cuda:0')
tensor([[14, 12,  8]], device='cuda:0', dtype=torch.int8)

w.t() and wt should be the same, unless there are some memory shenanigans I'm not aware of:

print(wt == w.t())
tensor([[True, True],
        [True, True],
        [True, True]], device='cuda:0')
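
Casting to int8, as in the matmul calls above, doesn't change anything on the value side either (a quick sanity check, nothing BitBLAS-specific):

print(torch.equal(wt.to(torch.int8), w.t().to(torch.int8)))
True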

Is it a known issue? Or am I missing something? Thanks in advance!

LeiWang1999 commented 1 month ago

Hi @rokada-br, thanks for your attention!

Looks like we need to make sure the tensor is contiguous when passing it to third-party kernels.

c = matmul(a.to(torch.int8), 
           matmul.transform_weight(w.t().contiguous().to(torch.int8)))
print(c)

And this works.
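
For context: the two tensors hold the same values, but w.t() is only a strided view of w, while wt owns its own contiguous row-major buffer, and the kernel reads raw memory, so the layout matters. You can see the difference with standard PyTorch introspection (using the w and wt from your sample):

print(wt.is_contiguous(), w.t().is_contiguous())
True False
print(wt.stride(), w.t().stride())
(2, 1) (1, 3)
# .contiguous() materializes the view into a fresh row-major buffer
print(w.t().contiguous().stride())
(2, 1)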

LeiWang1999 commented 1 month ago

Hi, @rokada-br , I've made a pull request to fix it!