Implement class `BaseScheduler`:
```python
from dataclasses import dataclass
from typing import Union

import tvm
from tvm import IRModule
from tvm.tir import PrimFunc


@dataclass
class BaseScheduler:
    # Leading underscore so the field does not collide with the
    # enable_simplify() method below.
    _enable_simplify: bool = True

    @staticmethod
    def Simplify(stmt: Union[PrimFunc, IRModule]):
        if isinstance(stmt, PrimFunc):
            return tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(stmt))["main"]
        elif isinstance(stmt, IRModule):
            return tvm.tir.transform.Simplify()(stmt)
        else:
            raise ValueError(f"Unsupported type: {type(stmt)}")

    def enable_simplify(self):
        self._enable_simplify = True
        return self

    def disable_simplify(self):
        self._enable_simplify = False
        return self

    def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]):
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt
```
This wraps the class methods shared by all concrete schedulers:
```python
matmul = MatmulScheduler(
    M=M,
    N=N,
    K=K,
    trans_A=trans_A,
    trans_B=trans_B,
    dtypeAB=dtypeAB,
    dtypeC=dtypeC,
    accum_dtype=accum_dtype,
).disable_simplify().with_default_config()

simplified = MatmulScheduler.Simplify(matmul)
```
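As a rough sketch of how a concrete scheduler plugs into this base class (the name `MyScheduler` and the trivial kernel body are hypothetical, not from this PR), `with_default_config` builds a `PrimFunc` and passes it through `maybe_simplify`, so the `enable_simplify()`/`disable_simplify()` toggle takes effect:

```python
from dataclasses import dataclass

from tvm.script import tir as T

# Reuses the BaseScheduler defined above.


@dataclass
class MyScheduler(BaseScheduler):  # hypothetical subclass, for illustration only
    M: int = 128
    N: int = 128

    def with_default_config(self):
        M, N = self.M, self.N

        # A trivial elementwise copy stands in for a real matmul schedule.
        @T.prim_func
        def main(A: T.Buffer((M, N), "float16"), B: T.Buffer((M, N), "float16")):
            for i, j in T.grid(M, N):
                with T.block("copy"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]

        # Honors the caller's enable_simplify()/disable_simplify() choice.
        return self.maybe_simplify(main)


func = MyScheduler(M=256, N=256).disable_simplify().with_default_config()
```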
Before applying simplification:
```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[T.min(0, by * 64):T.min(0, by * 64) + (T.max(159, by * 64 + 63) + 1 - T.min(0, by * 64)), T.min(by, 0) * 64:T.min(by, 0) * 64 + (T.max(by * 64 + 31, 127) + 1 - T.min(by, 0) * 64)], B[T.min(bx, 0) * 64:T.min(bx, 0) * 64 + (T.max(bx * 64 + 63, 159) + 1 - T.min(bx, 0) * 64), T.min(0, bx * 64):T.min(0, bx * 64) + (T.max(127, bx * 64 + 31) + 1 - T.min(0, bx * 64))])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        if T.bool(False):
            T.attr(None, "threadblock_swizzle_pattern", "tl::rasterization2DRow<10>")
            T.evaluate(0)
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            if T.bool(False):
                T.copy(T.region(A[k * 32, by * 64], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            if T.bool(True):
                T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(B[k * 32, bx * 64], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))
```
After simplification, the dead `if T.bool(False)`/`T.bool(True)` branches are folded away and the `T.reads` bounds collapse to constants:
```python
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[0:160, 0:128], B[0:160, 0:128])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))
```
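Since `Simplify` also accepts an `IRModule`, the same pass can be run over a whole module rather than a single `PrimFunc`; a minimal usage sketch, reusing `matmul` from the example above:

```python
import tvm

# Wrap the scheduled PrimFunc in a module and simplify everything in it.
mod = tvm.IRModule.from_expr(matmul)
simplified_mod = MatmulScheduler.Simplify(mod)
```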
This pull request introduces several significant updates and additions to the `bitblas` library, particularly focusing on matrix multiplication (matmul) operations. Key changes include the implementation of a new matmul function for dequantized weights, the addition of various matmul schedulers, and comprehensive testing for these schedulers.

**New Features and Implementations:**

- **Matmul Function for Dequantized Weights:** Added `matmul_blocked_weight_only` in `bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py` to handle matmul operations with dequantized weights.
- **Matmul Schedulers:** Added `MatmulScheduler`, `MatmulFineGrainScheduler`, and `MatmulWeightPropagationScheduler` in `bitblas/ops/general_matmul/tilelang/dense/__init__.py` for different matmul scheduling strategies.

**Testing Enhancements:**

- Added `testing/python/operators/test_general_matmul_tilelang_kernel.py` to validate the correctness and performance of the new matmul schedulers.

**Code Quality Improvements:**

- **License Additions:** Added license headers to `bitblas/ops/general_matmul/tilelang/dequantize/__init__.py` and `bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py`.
- **Code Refactoring:** Updated `testing/python/tilelang/test_tilelang_dequantize_gemm.py` to use `T.alloc_local` instead of `T.alloc_fragment`.
- `None` in `testing/python/tilelang/test_tilelang_dequantize_gemm.py`.

These changes collectively enhance the functionality, maintainability, and reliability of the `bitblas` library, particularly in the context of matrix multiplication operations.
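The added tests can be run directly with pytest, e.g. `pytest testing/python/operators/test_general_matmul_tilelang_kernel.py`.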