triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Large performance regression for FP8 E4M3 GEMM with `triton==2.3` #3828

Open mgoin opened 4 months ago

mgoin commented 4 months ago

There is a very large performance regression (6x slower for [8192,8192]x[8192,8192]) when using Triton for matmuls with float8 e4m3 inputs, comparing 2.2.0 and 2.3.0.

We use Triton for our fused MoE implementation in vLLM and noticed this regression while upgrading PyTorch from 2.2.1 to 2.3.0 (thanks @pcmoritz for detecting it quickly), which also upgraded Triton from 2.2.0 to 2.3.0.

This regression seems to go away if I use the latest nightly, but we are still stuck choosing between the latest stable PyTorch (which we would like to have for FP8 GEMM support on SM89) and acceptable FP8 performance with Triton. Is it possible this could be hotfixed?

Below I share my minimal reproduction using triton.ops.matmul on an H100:

Results:

> pip install triton==2.2 numpy torch
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.086390 seconds

> pip install triton==2.3
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.547999 seconds

> pip uninstall -y triton
> pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.088446 seconds
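
For reference, the timings above can be converted into a slowdown factor and an approximate throughput. This is only a rough sketch: it assumes the usual 2*M*N*K operation count for a dense matmul, and the variable and function names are made up for illustration.

```python
# Timings copied from the results above (seconds for 100 iterations).
timings = {"2.2": 0.086390, "2.3": 0.547999, "nightly": 0.088446}

M = N = K = 8192
iters = 100

# Assumption: a dense matmul costs one multiply and one add per inner-product term.
flop_per_matmul = 2 * M * N * K


def tflops(elapsed_seconds: float) -> float:
    """Approximate throughput in TFLOP/s for `iters` matmuls in `elapsed_seconds`."""
    return flop_per_matmul * iters / elapsed_seconds / 1e12


# Ratio of the 2.3 timing to the 2.2 timing.
slowdown = timings["2.3"] / timings["2.2"]

for version, elapsed in timings.items():
    print(f"triton {version}: {elapsed:.6f} s, ~{tflops(elapsed):.0f} TFLOP/s")
print(f"2.3 vs 2.2 slowdown: {slowdown:.2f}x")
```

The ratio comes out a little above 6, which matches the "6x slower" figure quoted at the top.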

Benchmarking script:

import triton
import triton.ops
import triton.language as tl
import torch
import time

benchmark_iters = 100

# Create input matrices
A = torch.randn(8192, 8192, dtype=torch.float16, device='cuda')
B = torch.randn(8192, 8192, dtype=torch.float16, device='cuda')

# Quantize
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn).T

# Convert to triton float8 dtype
A_fp8 = triton.reinterpret(A_fp8, tl.float8e4nv)
B_fp8 = triton.reinterpret(B_fp8, tl.float8e4nv)

print(f"Benchmarking [{A_fp8.shape}, {A_fp8.dtype}] x [{B_fp8.shape}, {B_fp8.dtype}]")

# Warm up GPU
for _ in range(10):
    c = triton.ops.matmul(A_fp8, B_fp8)
torch.cuda.synchronize()

# Timing the matmul
start_time = time.time()
for _ in range(benchmark_iters):
    c = triton.ops.matmul(A_fp8, B_fp8)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time

print(f"Elapsed time for {benchmark_iters} iterations: {elapsed_time:.6f} seconds")

atalman commented 4 months ago

cc @jansel @malfet @seemethere This looks like an H100-specific issue. On an A100 I get this error instead:

  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 45:31:            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if fp8_fast_accum:
            acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
                               ^
AssertionError('Dot op does not support fp8e4nv on CUDA arch < 90')

mgoin commented 4 months ago

Hey @atalman, both Triton and PyTorch only support FP8 GEMM on GPUs with hardware support for FP8 tensor cores. So this is intended to work only on Hopper (H100) or Ada Lovelace (L4, L40, RTX 4000 series).
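
A guard along these lines could skip the FP8 path on unsupported GPUs. This is a minimal sketch, not something from Triton or PyTorch: the helper name is hypothetical, and the SM89 threshold is taken from the comments in this thread. The assertion in the traceback above mentions "CUDA arch < 90", so a given Triton release may be stricter than this check.

```python
from typing import Tuple

# Assumption from this thread: FP8 tensor cores are available on Ada Lovelace
# (SM89) and Hopper (SM90). A particular Triton release may require more.
MIN_FP8_CAPABILITY = (8, 9)


def has_fp8_tensor_cores(capability: Tuple[int, int]) -> bool:
    """Hypothetical helper: True if (major, minor) is at least SM89."""
    return tuple(capability) >= MIN_FP8_CAPABILITY


# With PyTorch and a CUDA device available, the capability tuple would come from:
#     capability = torch.cuda.get_device_capability()
print(has_fp8_tensor_cores((8, 0)))  # A100
print(has_fp8_tensor_cores((8, 9)))  # L4, L40
print(has_fp8_tensor_cores((9, 0)))  # H100
```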

plotfi commented 4 months ago

It seems the change to maxNumImpreciseAcc from https://github.com/openai/triton/pull/2804 brings the run time for matmuls back to 2.2.x levels.

ThomasRaoux commented 4 months ago

> It seems the change to maxNumImpreciseAcc from #2804 brings the run time for matmuls back to 2.2.x levels.

Ah right, this is because before that change the accumulation was happening at a lower precision. To get that behavior back you need to use the three-source dot (acc = tl.dot(a, b, acc) instead of acc += tl.dot(a, b)), because the latter form suggests the user wants a 32-bit addition.
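
To see why the precision of the accumulator matters at all, here is a small pure-Python illustration. It does not use Triton or a GPU, and it uses IEEE half precision as a stand-in for a narrow accumulator. That is an assumption made purely for illustration, since the thread does not say which format the hardware accumulates in.

```python
import struct


def to_half(x: float) -> float:
    """Round to IEEE half precision (stand-in for a narrow accumulator)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_single(x: float) -> float:
    """Round to IEEE single precision (a 32-bit accumulator)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(terms, round_fn):
    """Sum `terms`, rounding the running total after every addition."""
    acc = 0.0
    for t in terms:
        acc = round_fn(acc + t)
    return acc


terms = [1.0] * 4096
narrow = accumulate(terms, to_half)    # running total kept in half precision
wide = accumulate(terms, to_single)    # running total kept in single precision
print(narrow, wide)
```

The half-precision total stops growing at 2048, because half precision has an 11-bit significand and 2048 + 1 rounds back to 2048. The single-precision total reaches 4096. Lower-precision accumulation is cheaper but lossy in the same way, which is why acc += tl.dot(a, b) is read as a request for a full 32-bit addition.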

pcmoritz commented 3 months ago

Fixed in triton 2.3.1 now :)