microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License
323 stars 29 forks source link

Further study required for the performance gap when using block reduction with Tensor Cores #64

Closed LeiWang1999 closed 3 weeks ago

LeiWang1999 commented 2 months ago

PR #63 involved considerable effort to implement block reduction with Tensor Cores, which is expected to improve the performance of continuous batching on Tensor Cores. However, performance is still worse than that of the TRT-LLM batched dequantize kernel, so a comprehensive study is required to understand the gap.

Code to reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas import Matmul, MatmulConfig
import argparse
import bitblas
import tvm
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.roller.arch import CUDA
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.utils import apply_and_build
import time
from tvm import te, tir

bitblas.set_log_level("DEBUG")

# Directory where the generated CUDA kernels are dumped; the block-reduction
# variant gets its own suffix so it does not overwrite the baseline kernel.
KERNEL_OUTPUT_DIR = "./debug/bitblas_fp16xint4_fp16_pb_noscale_with_default"


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false"/"1"/"0".

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so value-taking boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        # Defaults are passed through unchanged.
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Initialize the parser
parser = argparse.ArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target."
)

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking."
)
parser.add_argument(
    "--group_size",
    type=int,
    default=None,
    help="Group size for grouped quantization."
)
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A."
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="uint4",
    help="Data type of weight W."
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    help="Data type for accumulation."
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output."
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W."
)
parser.add_argument(
    "--with_bias",
    action="store_true",
    help="Include bias in the benchmark."
)
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization."
)
parser.add_argument(
    "--with_zeros",
    action="store_true",
    help="Include zeros in the quantization."
)
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros."
)
# The propagate flags previously reused the --zeros_mode definition
# (str type + zeros-mode choices), so no boolean value could be passed.
parser.add_argument(
    "--propagate_a",
    type=_str2bool,
    default=True,
    help="Whether to apply layout propagation to activation A (true/false)."
)
parser.add_argument(
    "--propagate_b",
    type=_str2bool,
    default=True,
    help="Whether to apply layout propagation to weight W (true/false)."
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
group_size = args.group_size
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode
propagate_a = args.propagate_a
propagate_b = args.propagate_b

test_shapes = [
    (MatmulConfig, Matmul, (16, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
]

benchmark_sets = []
benchmark_sets.extend(test_shapes)


def _make_hint_dict(intrin_info, **extra):
    """Return the shared tensor-core schedule hint, updated with ``extra`` keys.

    A fresh dict (and a fresh rasterization plan object) is built per call so
    the two schedules never share mutable hint state.
    """
    hint = {
        "warp": [16, 16],
        "block": [16, 64],
        "rstep": [128],
        "pipeline_stage": 2,
        "use_async": True,
        "intrin_info": intrin_info,
        "shared_scope": "shared",
        "vectorize": {
            "A": 8,
            "B": 8,
        },
        "rasterization_plan": bitblas.base.roller.Rasterization2DColumn(10),
    }
    hint.update(extra)
    return hint


def _build_and_time(prim_func, hint_dict, output_dir, profile_tensors, build_target):
    """Schedule ``prim_func`` with ``hint_dict``, build it and time it.

    Returns:
        Mean latency in milliseconds over 10 runs of the built kernel.
    """
    sch = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo().sch_shared_memory_prefetch_with_config(
        prim_func,
        bitblas.base.roller.hint.Hint().from_dict(hint_dict),
    )
    pass_config = {
        "tir.use_async_copy": True,
        "tir.merge_static_smem": False,
        "cuda.kernels_output_dir": output_dir,
    }
    with tvm.transform.PassContext(config=pass_config):
        rt_mod = tvm.build(sch.mod, target=build_target)
    time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(), number=10)
    # print(rt_mod.imported_modules[0].get_source())
    return time_evaluator(*profile_tensors).mean * 1e3


benchmark_results = {}
for config, operator, input_args in benchmark_sets:
    # Honour the parsed propagate flags instead of hard-coding True.
    matmul_config = config(*input_args, propagate_a=propagate_a, propagate_b=propagate_b, fast_decoding=True)
    matmul = operator(matmul_config, target=target, enable_tuning=False)
    func = matmul.prim_func

    # Transform kind 2 is what the original script paired with propagate=True.
    # Using 0 (no transform) when propagation is disabled is an assumption;
    # verify it against bitblas' TransformKind mapping.
    # NOTE(review): in_dtype/out_dtype stay hard-coded to float16 as before,
    # regardless of --A_dtype / --accum_dtype.
    intrin_info = bitblas.base.roller.hint.IntrinInfo(
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
        input_transform_kind=2 if propagate_a else 0,
        weight_transform_kind=2 if propagate_b else 0,
    )
    profile_tensors = matmul.get_profile_tensors()

    normal_latency = _build_and_time(
        func,
        _make_hint_dict(intrin_info),
        KERNEL_OUTPUT_DIR,
        profile_tensors,
        matmul.target,
    )
    print(f"Time cost (no block reduction) is: {normal_latency:.3f} ms")

    reduce_latency = _build_and_time(
        func,
        _make_hint_dict(intrin_info, block_reduction_depth=2),
        KERNEL_OUTPUT_DIR + "_block_reduction",
        profile_tensors,
        matmul.target,
    )
    print(f"Time cost (block reduction, depth 2) is: {reduce_latency:.3f} ms")

    benchmark_results[input_args] = {
        "normal": normal_latency,
        "block_reduction": reduce_latency,
    }
LeiWang1999 commented 2 months ago

Marlin appears to outperform both CUTLASS and BitBLAS on small int4 shapes. Marlin's CUDA kernel design sets the grid size to the number of streaming multiprocessors (e.g., 108). CUTLASS and BitBLAS instead derive the grid size from the tile count (e.g., 128). This produces a partial tail wave. For example, 128 thread blocks on 108 SMs run as one full wave plus a second wave that occupies only 20 SMs. That wastes roughly 40–50% of the available wave slots.

We should dig further.

LeiWang1999 commented 1 month ago

The FLUTE project (https://github.com/HanGuo97/flute) provides a CuTe-based version of Marlin. We should study its approach and integrate it.

LeiWang1999 commented 3 weeks ago

Conclusion: block reduction improves performance, so we enable it for small shapes.