NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] CUTLASS Python is 4x slower than default PyTorch GEMM #1662

Closed: OrenLeung closed this issue 2 months ago

OrenLeung commented 2 months ago

Hi,

I am trying to benchmark the difference in TFLOP/s between CUTLASS and cuBLAS (through PyTorch).

I am following how your Python example calls a GEMM op (link).

Unfortunately, CUTLASS only reaches 321 TFLOP/s on FP8 vs. 1296 TFLOP/s with cuBLAS. Do you have any suggestions on how to improve the performance? I have attached the repro script below.

Results

CUTLASS FP8 GEMM Average TFLOP/s: 321.6616572818387 TFLOP/s
torch._scaled_mm (cuBLAS) FP8 Average TFLOP/s: 1296.876406864292 TFLOP/s
CUTLASS BF16 GEMM Average TFLOP/s: 302.8308746739446 TFLOP/s
torch.matmul (cuBLAS) BF16 Average TFLOP/s: 764.9407720916588 TFLOP/s
Speed-up from using FP8 CUTLASS GEMM vs. FP8 torch._scaled_mm: 0.24802799679237134x
Speed-up from using BF16 CUTLASS GEMM vs. BF16 torch.matmul: 0.39588800299647x

Setup

Repro Script

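import torch
import cutlass
from triton.testing import do_bench

# Matrix dimensions
M, N, K = 8192, 8192, 8192

# Create input/output tensors in FP8 directly on the GPU.
# A is row-major (M, K); B is allocated as (N, K) and transposed into a column-major (K, N) view.
# C is FP32 to match element_C in the plan below; D uses the plan's `element` type.
A_fp8 = torch.randn((M, K), device="cuda").to(torch.float8_e4m3fn)
B_fp8 = torch.randn((N, K), device="cuda").to(torch.float8_e4m3fn).t()
C_fp8 = torch.zeros((M, N), dtype=torch.float32, device="cuda")
D_fp8 = torch.zeros((M, N), device="cuda").to(torch.float8_e4m3fn)

# Create input/output tensors in BF16 (same layouts and C/D convention as above)
A_bf16 = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
B_bf16 = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").t()
C_bf16 = torch.zeros((M, N), dtype=torch.float32, device="cuda")
D_bf16 = torch.zeros((M, N), dtype=torch.bfloat16, device="cuda")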

# FP8 CUTLASS GEMM plan: FP8 A/B/D, FP32 C and accumulator.
# C (and therefore D) is declared RowMajor to match the contiguous torch tensors above.
plan_fp8 = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,
                           layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,
                           layout_C=cutlass.LayoutType.RowMajor)

# BF16 CUTLASS GEMM plan: BF16 A/B/D, FP32 C and accumulator, RowMajor C/D.
plan_bf16 = cutlass.op.Gemm(element=torch.bfloat16, element_C=torch.float32, element_accumulator=torch.float32,
                            layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,
                            layout_C=cutlass.LayoutType.RowMajor)

# Function to run the FP8 CUTLASS GEMM operation
def run_gemm_fp8():
    plan_fp8.run(A_fp8, B_fp8, C_fp8, D_fp8, print_module=False)

# Function to run the BF16 CUTLASS GEMM operation
def run_gemm_bf16():
    plan_bf16.run(A_bf16, B_bf16, C_bf16, D_bf16, print_module=False)

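# Function to run the FP8 torch._scaled_mm operation.
# Newer PyTorch versions make scale_a/scale_b required; unit scales leave the product unscaled.
unit_scale = torch.tensor(1.0, device="cuda")
def run_scaled_mm_fp8():
    torch._scaled_mm(A_fp8, B_fp8, scale_a=unit_scale, scale_b=unit_scale)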

# Function to run the BF16 torch.matmul operation
def run_matmul_bf16():
    torch.matmul(A_bf16, B_bf16)

# Number of floating-point operations for one GEMM operation
flops_per_iteration = 2 * M * N * K

# Benchmark the FP8 CUTLASS GEMM operation
# (do_bench's warmup and rep arguments are time budgets in milliseconds, not iteration counts)
cutlass_fp8_time_ms = do_bench(run_gemm_fp8, warmup=20, rep=20)
cutlass_fp8_time_s = cutlass_fp8_time_ms / 1000
cutlass_fp8_tflops = flops_per_iteration / cutlass_fp8_time_s / 1e12

print(f"CUTLASS FP8 GEMM Average TFLOP/s: {cutlass_fp8_tflops} TFLOP/s")

# Benchmark the FP8 torch._scaled_mm operation
scaled_mm_fp8_time_ms = do_bench(run_scaled_mm_fp8, warmup=20, rep=20)
scaled_mm_fp8_time_s = scaled_mm_fp8_time_ms / 1000
scaled_mm_fp8_tflops = flops_per_iteration / scaled_mm_fp8_time_s / 1e12

print(f"torch._scaled_mm (cuBLAS) FP8 Average TFLOP/s: {scaled_mm_fp8_tflops} TFLOP/s")

# Benchmark the BF16 CUTLASS GEMM operation
cutlass_bf16_time_ms = do_bench(run_gemm_bf16, warmup=20, rep=20)
cutlass_bf16_time_s = cutlass_bf16_time_ms / 1000
cutlass_bf16_tflops = flops_per_iteration / cutlass_bf16_time_s / 1e12

print(f"CUTLASS BF16 GEMM Average TFLOP/s: {cutlass_bf16_tflops} TFLOP/s")

# Benchmark the BF16 torch.matmul operation
matmul_bf16_time_ms = do_bench(run_matmul_bf16, warmup=20, rep=20)
matmul_bf16_time_s = matmul_bf16_time_ms / 1000
matmul_bf16_tflops = flops_per_iteration / matmul_bf16_time_s / 1e12

print(f"torch.matmul (cuBLAS) BF16 Average TFLOP/s: {matmul_bf16_tflops} TFLOP/s")

# Calculate the speed-up for FP8 CUTLASS vs. FP8 torch._scaled_mm
speed_up_fp8 = scaled_mm_fp8_time_s / cutlass_fp8_time_s
print(f"Speed-up from using FP8 CUTLASS GEMM vs. FP8 torch._scaled_mm: {speed_up_fp8}x")

# Calculate the speed-up for BF16 CUTLASS vs. BF16 torch.matmul
speed_up_bf16 = matmul_bf16_time_s / cutlass_bf16_time_s
print(f"Speed-up from using BF16 CUTLASS GEMM vs. BF16 torch.matmul: {speed_up_bf16}x")
thakkarV commented 2 months ago

I have no clue how the PyTorch integration works here, but CUTLASS is not a prebuilt kernel library like cuBLAS, nor does it have any heuristics. It requires an expert who knows which kernel will work best to instantiate that kernel and run it. In lieu of this, you can also build a ton of kernels in the library and autotune them using the CUTLASS profiler; however, even in that case you are only going to build an extremely small subset of the millions of possible kernels CUTLASS supports. You are also likely to miss out on many tuning knobs, like the rasterization remapping we just released with 3.5.1. Out-of-the-box comparisons of CUTLASS with anything else (cuBLAS, Triton) are not straightforward and require deep knowledge of GPU architecture, of CUTLASS itself, and of whatever you are comparing it against in order to ensure a faithful comparison.
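As a rough illustration of the autotuning idea, here is a minimal, library-agnostic sketch: try each candidate kernel configuration, skip any that fail to compile, time the rest, and keep the fastest. The `autotune` helper and its `compile_fn`/`bench_fn` callables are hypothetical names introduced for this sketch; it is plain exhaustive search, not the CUTLASS profiler's own procedure.

```python
import math
from typing import Any, Callable, Iterable, Optional, Tuple


def autotune(candidates: Iterable[Any],
             compile_fn: Callable[[Any], None],
             bench_fn: Callable[[], float]) -> Tuple[Optional[Any], float]:
    """Hypothetical helper: return the fastest candidate and its time.

    compile_fn(cfg) configures the kernel for cfg and may raise if cfg is invalid.
    bench_fn() returns the runtime of the currently configured kernel.
    """
    best_cfg, best_time = None, math.inf
    for cfg in candidates:
        try:
            compile_fn(cfg)
        except Exception:
            continue  # skip configurations that fail to compile for these operands
        t = bench_fn()
        if t < best_time:
            best_cfg, best_time = cfg, t
    if best_cfg is not None:
        compile_fn(best_cfg)  # leave the kernel configured with the winner
    return best_cfg, best_time
```

With the CUTLASS Python interface, one plausible wiring (an assumption based on the interface's example notebooks, not something tested in this thread) is `autotune(plan_fp8.tile_descriptions(), plan_fp8.compile, lambda: do_bench(run_gemm_fp8, warmup=20, rep=20))`. As noted above, this only searches the tile shapes the interface enumerates, a tiny slice of the kernel space.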

thakkarV commented 2 months ago

That said, @jackkosaian, I'd think our Python interface should be picking a better default here. For

M, N, K = 8192, 8192, 8192

CUTLASS should be hitting >= 1.5 PFLOP/s
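For scale, a quick back-of-envelope check: an 8192 x 8192 x 8192 GEMM performs 2·M·N·K ≈ 1.1 × 10^12 floating-point operations, so a 1.5 PFLOP/s rate corresponds to roughly 0.73 ms per GEMM, versus about 3.4 ms at the 321 TFLOP/s measured for CUTLASS FP8 in the results above. The helper names below are illustrative only.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    # Each of the m*n outputs needs k multiply-adds, i.e. 2*k floating-point operations.
    return 2 * m * n * k


def time_ms_at(tflops: float, m: int, n: int, k: int) -> float:
    # Runtime (ms) of an m x n x k GEMM sustained at the given TFLOP/s rate.
    return gemm_flops(m, n, k) / (tflops * 1e12) * 1e3


M, N, K = 8192, 8192, 8192
print(gemm_flops(M, N, K))                      # 1099511627776
print(f"{time_ms_at(1500, M, N, K):.3f} ms")    # 0.733 ms (1.5 PFLOP/s)
print(f"{time_ms_at(321.66, M, N, K):.3f} ms")  # 3.418 ms (measured CUTLASS FP8 rate)
```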

OrenLeung commented 2 months ago

Hi Vijay (@thakkarV),

I appreciate your quick reply.

I do appreciate the flexibility of CUTLASS. For example, I am trying to run an e5m2-by-e5m2 GEMM, which cuBLAS does not support since that combination is never used in ML.

In lieu of this, you can also build a ton of kernels in the library and autotune them using the CUTLASS profiler; however, even in that case you are only going to build an extremely small subset of the millions of possible kernels CUTLASS supports.

Thanks for the tip about the profiler; I will try using it to build INT8 and e5m2 kernels.