microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License

[TL] Wrap TL Kernel with Scheduler #199

Closed. LeiWang1999 closed this 1 month ago.

LeiWang1999 commented 1 month ago

This pull request introduces several significant updates and additions to the bitblas library, particularly focusing on matrix multiplication (matmul) operations. Key changes include the implementation of a new matmul function for dequantized weights, the addition of various matmul schedulers, and the inclusion of comprehensive testing for these schedulers.

The changes fall into three areas:

- New Features and Implementations
- Testing Enhancements
- Code Quality Improvements

These changes collectively enhance the functionality, maintainability, and reliability of the bitblas library, particularly in the context of matrix multiplication operations.

from dataclasses import dataclass

import tvm.tl.language as T  # TileLang frontend; the exact import path may vary across TVM/BitBLAS versions


@dataclass
class MatmulScheduler:

    # OP Related Config
    M: int
    N: int
    K: int
    trans_A: bool = False
    trans_B: bool = False
    dtypeAB: str = "float16"
    dtypeC: str = "float16"
    accum_dtype: str = "float16"

    # Default Tile Related Params
    block_M: int = 64
    block_N: int = 64
    block_K: int = 32
    num_stages: int = 2
    threads: int = 128
    enable_rasterization: bool = False  # Enhance L2 Locality

    def with_default_config(self):
        # The tile parameters are dataclass fields with defaults,
        # so they can be read directly off the instance.
        return self.apply_config(
            block_M=self.block_M,
            block_N=self.block_N,
            block_K=self.block_K,
            num_stages=self.num_stages,
            threads=self.threads,
            enable_rasterization=self.enable_rasterization,
        )

    def apply_config(
        self,
        block_M=64,
        block_N=64,
        block_K=32,
        num_stages=2,
        threads=128,
        # Enhance L2 Locality
        enable_rasterization=False,
    ):
        M, N, K = self.M, self.N, self.K
        trans_A, trans_B = self.trans_A, self.trans_B
        dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype

        A_shape = (K, M) if trans_A else (M, K)
        B_shape = (N, K) if trans_B else (K, N)
        A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
        B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, dtypeAB),
                B: T.Buffer(B_shape, dtypeAB),
                C: T.Buffer((M, N), dtypeC),
        ):
            # Launch one thread block per (block_M, block_N) output tile of C.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
                B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                if enable_rasterization:
                    # Swizzle the block-to-tile mapping to improve L2 locality;
                    # 10 is the rasterization panel size.
                    T.use_swizzle(10)

                T.clear(C_local)  # zero-initialize the accumulator
                # Software-pipelined reduction over tiles of K
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    if trans_A:
                        T.copy(A[k * block_K, by * block_M], A_shared)
                    else:
                        T.copy(A[by * block_M, k * block_K], A_shared)
                    if trans_B:
                        T.copy(B[bx * block_N, k * block_K], B_shared)
                    else:
                        T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    def __post_init__(self):
        # TODO: validate the config here (e.g. dtype support, tile-size constraints).
        return
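For reference, a custom tile configuration can be passed straight to apply_config. A minimal sketch follows; the problem sizes and tile values are illustrative, not tuned:

# Illustrative problem and tile sizes; not tuned values.
scheduler = MatmulScheduler(M=1024, N=1024, K=1024, trans_A=False, trans_B=True)
kernel = scheduler.apply_config(
    block_M=128,
    block_N=128,
    block_K=32,
    num_stages=3,
    threads=256,
    enable_rasterization=True,  # opt in to the L2-friendly block swizzle
)
# `kernel` is a TVM PrimFunc that can be lowered and compiled.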
LeiWang1999 commented 1 month ago

Implement class BaseScheduler to host the methods shared across schedulers:

from dataclasses import dataclass
from typing import Union

import tvm
from tvm import IRModule
from tvm.tir import PrimFunc


@dataclass
class BaseScheduler:

    # Private field so it does not collide with the enable_simplify() method below.
    _enable_simplify: bool = True

    @staticmethod
    def Simplify(stmt: Union[PrimFunc, IRModule]):
        if isinstance(stmt, PrimFunc):
            return tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(stmt))["main"]
        elif isinstance(stmt, IRModule):
            return tvm.tir.transform.Simplify()(stmt)
        else:
            raise ValueError(f"Unsupported type: {type(stmt)}")

    def enable_simplify(self):
        self._enable_simplify = True
        return self

    def disable_simplify(self):
        self._enable_simplify = False
        return self

    def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]):
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt

This wraps the class methods common to all schedulers, so each concrete scheduler can inherit them; a sketch of the inheritance is shown below.
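A minimal sketch of how MatmulScheduler could inherit the shared behavior (the field defaults and the maybe_simplify wiring here are illustrative, not the exact diff):

@dataclass
class MatmulScheduler(BaseScheduler):
    # Illustrative defaults; dataclass inheritance requires defaults here
    # because BaseScheduler already declares a defaulted field.
    M: int = 1024
    N: int = 1024
    K: int = 1024

    def with_default_config(self):
        # Route the generated PrimFunc through the inherited maybe_simplify,
        # so disable_simplify() is honored (apply_config as defined above).
        return self.maybe_simplify(self.apply_config())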

matmul = MatmulScheduler(
    M=M,
    N=N,
    K=K,
    trans_A=trans_A,
    trans_B=trans_B,
    dtypeAB=dtypeAB,
    dtypeC=dtypeC,
    accum_dtype=accum_dtype,
).disable_simplify().with_default_config()

simplified = MatmulScheduler.Simplify(matmul)

Before applying simplification:

@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[T.min(0, by * 64):T.min(0, by * 64) + (T.max(159, by * 64 + 63) + 1 - T.min(0, by * 64)), T.min(by, 0) * 64:T.min(by, 0) * 64 + (T.max(by * 64 + 31, 127) + 1 - T.min(by, 0) * 64)], B[T.min(bx, 0) * 64:T.min(bx, 0) * 64 + (T.max(bx * 64 + 63, 159) + 1 - T.min(bx, 0) * 64), T.min(0, bx * 64):T.min(0, bx * 64) + (T.max(127, bx * 64 + 31) + 1 - T.min(0, bx * 64))])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        if T.bool(False):
            T.attr(None, "threadblock_swizzle_pattern", "tl::rasterization2DRow<10>")
            T.evaluate(0)
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            if T.bool(False):
                T.copy(T.region(A[k * 32, by * 64], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            if T.bool(True):
                T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(B[k * 32, bx * 64], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))

After applying simplification, the dead trans_A/trans_B branches (if T.bool(False)) are folded away and the read regions collapse to constant bounds:

# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[0:160, 0:128], B[0:160, 0:128])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))
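
For completeness, the simplified PrimFunc can be lowered to a runnable module and its generated device code inspected. A minimal sketch, assuming the tvm.tl helpers from BitBLAS's TVM fork (the import path and lower signature may differ between versions):

import tvm.tl as TL  # assumed import path for the TileLang helpers

# Lower the TL prim_func to a runtime module and its parameter descriptors.
rt_mod, params = TL.lower(simplified)

# Print the generated device source (e.g. CUDA) for inspection.
print(rt_mod.imported_modules[0].get_source())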