iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Lack of warp tiling for CUDA GEMM codegen. #12194

Open JamesTheZ opened 1 year ago

JamesTheZ commented 1 year ago

I am studying IREE's Tensor Core GEMM codegen and have noticed a large performance gap between IREE and cuBLAS. For example, I run a GEMM with [M, N, K] = [1024, 512, 1024] using the command and input file below.

The command is:
../iree-build/tools/iree-run-mlir \
  --iree-hal-target-backends=cuda \
  --iree-hal-cuda-llvm-target-arch=sm_86 \
  matmul.mlir \
  --input="1024x1024xf16=-1" --input="1024x512xf16=-1" --input="1024x512xf16=-1"

Content in the matmul.mlir is:
#compilation0 = #iree_codegen.compilation_info<
  lowering_config = <tile_sizes = [[128, 128, 32]]>,
  translation_info = <LLVMGPUMatmulTensorCore
  pipeline_depth = 3>,
  workgroup_size = [64 : index, 2 : index, 1 : index]>
func.func @matmul_1024x1024xf16_times_1024x512xf16_into_1024x512xf16_for_LLVMGPUMatmulTensorCore_128_128_32_64_2_1(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x512xf16>, %acc: tensor<1024x512xf16>) -> tensor<1024x512xf16> {
  %result = linalg.matmul {compilation_info = #compilation0} ins(%lhs, %rhs: tensor<1024x1024xf16>, tensor<1024x512xf16>) outs(%acc: tensor<1024x512xf16>) -> tensor<1024x512xf16>
  return %result: tensor<1024x512xf16>
}

Measured with Nsight Compute, the IREE kernel takes 62 us, while the cuBLAS version takes only 30.5 us. Is this the expected performance? Am I right that lowering_config = <tile_sizes = [[128, 128, 32]]> is the block-level tiling configuration? Are there other tuning parameters that can speed up GEMM codegen in IREE?
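To put the two timings in context, here is a rough sketch that converts them into achieved throughput. It assumes the usual 2*M*N*K FLOP count for a GEMM and takes the two durations as quoted above; the function names are made up for illustration and are not part of IREE or cuBLAS.

```python
# Rough throughput estimate for the GEMM timings quoted above.
# ASSUMPTION: a GEMM performs 2*M*N*K floating-point operations
# (one multiply and one add per inner-product term).

def gemm_flops(m: int, n: int, k: int) -> int:
    """Number of floating-point operations for an MxNxK GEMM."""
    return 2 * m * n * k


def achieved_tflops(m: int, n: int, k: int, duration_us: float) -> float:
    """Achieved throughput in TFLOP/s for a kernel lasting duration_us."""
    seconds = duration_us * 1e-6
    return gemm_flops(m, n, k) / seconds / 1e12


M, N, K = 1024, 512, 1024
iree_us, cublas_us = 62.0, 30.5  # durations reported in this issue

print(f"FLOPs:   {gemm_flops(M, N, K)}")
print(f"IREE:    {achieved_tflops(M, N, K, iree_us):.1f} TFLOP/s")
print(f"cuBLAS:  {achieved_tflops(M, N, K, cublas_us):.1f} TFLOP/s")
print(f"Slowdown vs cuBLAS: {iree_us / cublas_us:.2f}x")
```

With these numbers the gap is roughly a factor of two, which is the difference the questions above are about.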

I dug into the IREE passes for Tensor Core GEMM codegen and found only block-level tiling. Is there warp-level, or even thread-level, tiling in the GEMM schedule, as in CUTLASS?
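For reference, this is the tiling hierarchy I would expect the configuration above to imply, written as a small arithmetic sketch. It is not taken from the IREE source. It assumes a warp size of 32, that the workgroup's x dimension splits the N dimension of the block tile and its y dimension splits M, and that the block tile is divided evenly among warps as in a CUTLASS-style schedule. The helper name is hypothetical.

```python
# Sketch of the block/warp tiling implied by the compilation_info above.
# ASSUMPTIONS (not confirmed against IREE's implementation):
#   - warp size is 32 threads
#   - workgroup x splits the N dimension, workgroup y splits M
#   - the block tile is divided evenly among the warps

def tiling_summary(problem, block_tile, workgroup_size, warp_size=32):
    m, n, k = problem
    bm, bn, bk = block_tile
    wg_x, wg_y, wg_z = workgroup_size

    # Number of workgroups (thread blocks) needed to cover the output.
    workgroups = -(-m // bm) * -(-n // bn)  # ceiling division

    # Warps per workgroup along each dimension.
    warps = (wg_x // warp_size, wg_y, wg_z)

    # Per-warp output tile if the block tile is split evenly.
    warp_tile = (bm // warps[1], bn // warps[0])

    # Iterations of the K loop per workgroup.
    k_steps = -(-k // bk)

    return {
        "workgroups": workgroups,
        "warps_per_workgroup": warps[0] * warps[1] * warps[2],
        "warp_grid": warps,
        "warp_tile_mn": warp_tile,
        "k_steps": k_steps,
    }


summary = tiling_summary(
    problem=(1024, 512, 1024),
    block_tile=(128, 128, 32),
    workgroup_size=(64, 2, 1),
)
for name, value in summary.items():
    print(f"{name}: {value}")
```

Under these assumptions each workgroup has four warps and each warp would own a 64x64 piece of the 128x128 block tile. Whether IREE actually performs that second level of tiling is what I could not find in the passes.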

JamesTheZ commented 1 year ago

I just want to measure the best performance that IREE's Tensor Core codegen can reach on a MatMul op. Is there a tool to tune and measure the performance of a single op?

MaheshRavishankar commented 1 year ago

There is warp-level tiling in IREE, but I am not clear what you are looking for. @ThomasRaoux can maybe provide more details.