microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License

Matmul Output Hardware Mismatch #187

Open KeremTurgutlu opened 1 month ago

KeremTurgutlu commented 1 month ago

I am using a vLLM integration similar to this to run models with the BitBLAS backend. When tested, our model's generations were nonsense for most prompts on an A100. However, the same model works fine on an A6000.

To debug further, I saved the intermediate activations from the vLLM model for the prompt that fails and found that the activations after the first QKV layer contain NaN values, as shown below:

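A100 activations:

[Screenshot 2024-09-18 at 2:29:01 PM: QKV output activations on the A100, containing NaN values]

A6000 activations (the same activations are fine on an A6000):

[Screenshot 2024-09-18 at 2:29:21 PM: QKV output activations on the A6000]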
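Finding the first layer whose output turns non-finite, as done above for the QKV layer, can be scripted over the saved activations. A minimal sketch, assuming the intermediates were dumped as an ordered mapping from layer name to array and converted to NumPy; the helper `first_nonfinite_layer` and the layer names are hypothetical:

```python
from collections import OrderedDict

import numpy as np


def first_nonfinite_layer(activations):
    """Return (name, nan_count, inf_count) for the first layer whose output
    contains NaN or inf, or None if every layer is finite.

    `activations` maps layer name -> array, in forward-pass order.
    """
    for name, act in activations.items():
        arr = np.asarray(act, dtype=np.float32)
        if not np.isfinite(arr).all():
            return name, int(np.isnan(arr).sum()), int(np.isinf(arr).sum())
    return None


# Hypothetical dump: the layernorm output is finite, the QKV output is not.
acts = OrderedDict(
    [
        ("layers.0.input_layernorm", np.ones((4, 8), dtype=np.float16)),
        ("layers.0.qkv_proj", np.array([[1.0, np.nan], [np.inf, 2.0]], dtype=np.float16)),
        ("layers.0.o_proj", np.full((4, 8), np.nan, dtype=np.float16)),
    ]
)
print(first_nonfinite_layer(acts))  # ('layers.0.qkv_proj', 1, 1)
```

Stopping at the first non-finite layer matters because NaNs propagate forward, so every later layer (like `o_proj` here) would also look broken.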

To compare the A100 against the A6000, and to provide a minimal repro, I saved the input tensor that is passed to the QKV layer (which fails on the A100 but succeeds on the A6000), the BitBLAS-compatible quantized weights of the QKV layer, and the expected output from the A6000.

```python
import os

# Must be set before anything initializes CUDA, otherwise it does not affect device ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import matplotlib.pyplot as plt
import torch
import bitblas
from bitblas.cache import global_operator_cache
from bitblas.module import auto_detect_nvidia_target

BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = "/workspace/.cache/bitblas"


def _get_or_create_bitblas_operator(config):
    # On first use, load previously tuned operators for this GPU target.
    if global_operator_cache.size() == 0:
        global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)

    bitblas_matmul = global_operator_cache.get(config)
    if bitblas_matmul is None:
        # Cache miss: build the operator (default tuning is topk=20) and persist it for this target.
        bitblas_matmul = bitblas.Matmul(config)
        # bitblas_matmul.hardware_aware_finetune(topk=20)
        global_operator_cache.add(config, bitblas_matmul)
        global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
        print("BitBLAS Tuning done, appended operator to global_operator_cache.")
    else:
        print("BitBLAS Operator found in global_operator_cache.")
    return bitblas_matmul


# load qweight, zeros, and scales (saved on the A6000)
bitblas_weights = torch.load("/workspace/data/single_input_qkvproj_weights_a6000.pt")
qweight = bitblas_weights["qweight"].cuda()
scales = bitblas_weights["scales"].cuda()
zeros = bitblas_weights["zeros"].cuda()

# layernorm output which is passed to QKV
ln_act = torch.load("/workspace/data/single_input_ln0_act_a6000.pt").cuda()

# init matmul engine
K, N = 8192, 10240
bitblas_dtype = torch.float16
GROUPSIZE = 128
BITBLAS_OPT_M = [1, 16, 32, 64, 128, 256, 512]
NBITS = 4

matmul_config = bitblas.MatmulConfig(
    M=BITBLAS_OPT_M,
    N=N,
    K=K,
    A_dtype="bfloat16" if bitblas_dtype == torch.bfloat16 else "float16",
    W_dtype={4: "uint4", 2: "uint2"}[NBITS],
    accum_dtype="float32" if bitblas_dtype == torch.bfloat16 else "float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=GROUPSIZE,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    # fast_decoding=True,
)
matmul_eng = _get_or_create_bitblas_operator(matmul_config)

# matmul on A100
out = matmul_eng(ln_act, qweight, scales, zeros)

# this passes now: the standalone call produces no NaNs
assert not out.isnan().any().item()

# load expected output from the successful A6000 (moved to CPU for the comparison)
expected_output = torch.load("/workspace/data/single_input_bitblas_output_a6000.pt").cpu()

# compute relative difference between outputs
eps = 1e-4
out = out.cpu()
rel_diff = (out - expected_output).abs() / (expected_output.abs() + eps)
# Observed on the A100: min 0, max 27696 (note the very high maximum difference)
print(rel_diff.min(), rel_diff.max())

# relative difference heatmap (log scale)
plt.imshow((rel_diff + eps).log().float().numpy(), aspect="auto")
plt.colorbar()
plt.show()
```
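The relative difference above divides by `|expected| + eps` with `eps = 1e-4`, so entries where the A6000 output is close to zero can produce very large ratios even when the absolute error is small. A complementary summary, sketched here with NumPy, uses the combined test from `numpy.isclose` (`|a - b| <= atol + rtol * |b|`) and also counts non-finite values. The helper name and the tolerances are assumptions chosen for float16 outputs, not values from BitBLAS:

```python
import numpy as np


def mismatch_summary(out, expected, rtol=1e-2, atol=1e-2):
    """Summarize the difference between two same-shape arrays."""
    out = np.asarray(out, dtype=np.float32)
    expected = np.asarray(expected, dtype=np.float32)
    abs_err = np.abs(out - expected)
    # Same criterion as numpy.isclose: |out - expected| <= atol + rtol * |expected|.
    # NaN never compares close here (equal_nan defaults to False).
    close = np.isclose(out, expected, rtol=rtol, atol=atol)
    finite = np.isfinite(abs_err)
    return {
        "nonfinite_out": int((~np.isfinite(out)).sum()),
        "mismatch_fraction": float(1.0 - close.mean()),
        "max_abs_err": float(abs_err[finite].max()) if finite.any() else float("nan"),
    }


# Toy data: entry (0, 1) is near zero with a tiny error; entry (1, 1) is genuinely wrong.
expected = np.array([[1.0, 0.0], [2.0, -3.0]])
out = np.array([[1.001, 0.004], [2.0, 5.0]])

eps = 1e-4
rel = np.abs(out - expected) / (np.abs(expected) + eps)
# The eps-based ratio peaks at the near-zero entry (about 40),
# not at the entry that is actually wrong (about 2.7).
row, col = np.unravel_index(rel.argmax(), rel.shape)
print(int(row), int(col))  # 0 1
print(mismatch_summary(out, expected))
# {'nonfinite_out': 0, 'mismatch_fraction': 0.25, 'max_abs_err': 8.0}
```

The tolerances are deliberately loose for float16 outputs; tighten them if the comparison should be stricter.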

The difference is especially large on the right side of the output matrix, i.e. at the higher output-column (N) indices.

[Screenshot 2024-09-18 at 1:42:41 PM: heatmap of the log relative difference between the A100 and A6000 outputs]
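To put numbers on the right-side pattern in the heatmap, the mismatch can be aggregated over blocks of output columns (the N dimension, which `imshow` puts on the x-axis). A minimal NumPy sketch; the helper name and the block width are arbitrary choices for illustration. It makes it easier to see whether the error is confined to specific column ranges:

```python
import numpy as np


def mismatch_by_column_block(out, expected, block=1024, rtol=1e-2, atol=1e-2):
    """Return (col_start, col_end, mismatch_fraction) for each block of columns."""
    out = np.asarray(out, dtype=np.float32)
    expected = np.asarray(expected, dtype=np.float32)
    bad = ~np.isclose(out, expected, rtol=rtol, atol=atol)  # NaN counts as bad
    n_cols = out.shape[1]
    rows = []
    for start in range(0, n_cols, block):
        end = min(start + block, n_cols)
        rows.append((start, end, float(bad[:, start:end].mean())))
    return rows


# Synthetic example: 8 output columns, with the last 3 columns perturbed.
expected = np.ones((2, 8), dtype=np.float32)
out = expected.copy()
out[:, 5:] += 1.0
for start, end, frac in mismatch_by_column_block(out, expected, block=4):
    print(f"cols {start}:{end}  mismatch {frac:.2f}")
# cols 0:4  mismatch 0.00
# cols 4:8  mismatch 0.75
```

For the real repro, passing the CPU tensors (`out.float().numpy()`, `expected_output.float().numpy()`) with N = 10240 and the default block width gives ten blocks.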

Do you have an explanation for this, or suggestions for how to find the root cause of the mismatch? Thanks!

Note: I also tried running the model with dtype=bfloat16, accum_dtype=float32, and out_dtype=float16, but the generations were still nonsensical (e.g., repeated parentheses), probably due to NaNs.
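For reference when reading the dtype combinations above: float16 has a largest finite value of 65504, so a sum that exceeds it becomes inf, and arithmetic on infinities such as inf - inf produces NaN. bfloat16 keeps float32's 8-bit exponent and therefore roughly float32's range, at lower precision. A small NumPy illustration (NumPy has no bfloat16 type, so only float16 and float32 are shown):

```python
import numpy as np

print(np.finfo(np.float16).max)  # 65504.0
print(np.finfo(np.float32).max)  # ~3.4e38

with np.errstate(over="ignore", invalid="ignore"):
    # Adding two large float16 values overflows to inf.
    acc16 = np.float16(60000) + np.float16(60000)
    # Once an inf exists, inf - inf (e.g. in a later subtraction) is NaN.
    nan16 = acc16 - acc16

# The same sum in float32 stays finite.
acc32 = np.float32(60000) + np.float32(60000)

print(acc16, nan16, acc32)  # inf nan 120000.0
```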

KeremTurgutlu commented 1 month ago

repro_bitblas.zip: the input tensor, the quantized QKV weights, and the expected output are in the zip file.
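Before comparing outputs across GPUs, it may help to confirm that the tensors in the zip agree with the `MatmulConfig` in the repro (K=8192, N=10240, group_size=128, 4-bit weights). A minimal pure-Python sketch of the expected sizes, assuming the packed weight stores `nbits` bits per weight with no padding, and one scale and one zero point per group of `group_size` weights along K in each output row. These layout assumptions are not taken from BitBLAS and should be checked against its weight transform. Comparing element counts rather than exact shapes keeps the check independent of how the packed weight is reordered:

```python
def expected_quant_sizes(n, k, nbits, group_size):
    """Expected storage sizes for a group-quantized N x K weight (assumed layout)."""
    if k % group_size:
        raise ValueError("K must be a multiple of group_size")
    if (k * nbits) % 8:
        raise ValueError("K * nbits must be a whole number of bytes")
    groups_per_row = k // group_size
    return {
        "qweight_bytes": n * k * nbits // 8,  # packed integer storage
        "scales_numel": n * groups_per_row,   # assumed: one scale per group per output row
        "zeros_numel": n * groups_per_row,    # assumed: one zero point per group per output row
    }


sizes = expected_quant_sizes(n=10240, k=8192, nbits=4, group_size=128)
print(sizes)
# {'qweight_bytes': 41943040, 'scales_numel': 655360, 'zeros_numel': 655360}
```

With the loaded PyTorch tensors, `qweight.numel() * qweight.element_size()`, `scales.numel()`, and `zeros.numel()` are the values to compare against.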