triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

feat(hardware support): verify whether Hopper optimizations apply to Ada Lovelace (sm_89) #2192

Open jon-chuang opened 10 months ago

jon-chuang commented 10 months ago

Although server-class GPUs (A100, H100) are the main production targets, some users run their own servers with consumer GPUs or want to develop optimizations on a local GPU. Hence, we should try to support Hopper-specific features on sm_89 where possible. Currently these features may simply be ignored on sm_89.

Actually, the NVIDIA docs seem pretty clear:

[image from the NVIDIA documentation]

It seems the only feature Hopper and Ada Lovelace (or other sm_89 GPUs) share is FP8 tensor core support, which is not listed explicitly.

Jokeren commented 10 months ago

wgmma currently only supports sm_90 according to the public docs, so the comment is OK IMHO.
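The constraints gathered in this thread can be summarized as a small capability check. This is a minimal sketch, assuming the compute capability is given as a `(major, minor)` tuple (the shape returned by `torch.cuda.get_device_capability()`); `available_features` and its feature keys are hypothetical and not part of Triton's API. Treating TMA as sm_90-or-higher is an assumption based only on the sm_89 failure reported here.

```python
from typing import Dict, Tuple


def available_features(capability: Tuple[int, int]) -> Dict[str, bool]:
    """Hypothetical helper (not Triton API): which Hopper-era features a GPU supports.

    capability: (major, minor), e.g. (8, 9) for Ada Lovelace, (9, 0) for Hopper.
    """
    major, minor = capability
    sm = major * 10 + minor  # (8, 9) -> 89, (9, 0) -> 90
    return {
        # FP8 tensor cores are shared by Ada Lovelace (sm_89) and Hopper.
        "fp8_tensor_core": sm >= 89,
        # wgmma is only documented for sm_90.
        "wgmma": sm == 90,
        # ptxas rejects %clusterid / %cluster_ctaid below sm_90.
        "thread_block_cluster": sm >= 90,
        # Assumption: TMA failed on sm_89 here, so treat it as sm_90 or higher.
        "tma": sm >= 90,
    }


print(available_features((8, 9)))
```

On an RTX 4070 (`(8, 9)`) this reports FP8 tensor cores as the only available feature, which matches the reading of the NVIDIA docs above.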

jon-chuang commented 10 months ago

Cool, let me see whether TMA improves sm_89 (Ada Lovelace) performance or results in any errors.

jon-chuang commented 10 months ago

I found some evidence on sm_89 (RTX 4070) that TMA and thread block clusters do not work:

ptxas /tmp/compile-ptx-src-959430, line 41; error   : Feature '%clusterid' requires .target sm_90 or higher
ptxas /tmp/compile-ptx-src-7ff6b4, line 76; error   : Feature '%cluster_ctaid' requires .target sm_90 or higher
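When probing which features fail on a given target, ptxas messages like the two above can be collected mechanically. This is a minimal sketch; the regular expression is written against these two messages only, and other ptxas errors may be worded differently. `unsupported_features` is a hypothetical helper name.

```python
import re
from typing import List, Tuple

# Matches ptxas errors of the form
# "... error   : Feature '%clusterid' requires .target sm_90 or higher"
_FEATURE_RE = re.compile(r"Feature '([^']+)' requires \.target (sm_\d+) or higher")


def unsupported_features(ptxas_log: str) -> List[Tuple[str, str]]:
    """Return (feature, required_target) pairs found in a ptxas error log."""
    return _FEATURE_RE.findall(ptxas_log)


log = """\
ptxas /tmp/compile-ptx-src-959430, line 41; error   : Feature '%clusterid' requires .target sm_90 or higher
ptxas /tmp/compile-ptx-src-7ff6b4, line 76; error   : Feature '%cluster_ctaid' requires .target sm_90 or higher
"""
print(unsupported_features(log))
# [('%clusterid', 'sm_90'), ('%cluster_ctaid', 'sm_90')]
```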

Using TMA results in a Python abort, e.g.:

test/unit/hopper/test_persistent_warp_specialized_gemm.py::test_user_defined_persistent_warp_specialized_gemm[2048-2048-64-64-64-16-1-False-True-True] Fatal Python error: Aborted

Current thread 0x00007efe63f8b000 (most recent call first):
  File "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/compiler/compiler.py", line 49 in ttir_compute_capability_rewrite
  File "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/compiler/compiler.py", line 55 in optimize_ttir
  File "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/compiler/compiler.py", line 382 in <lambda>
  File "/home/jonch/Desktop/Programming/mlsys/triton/python/triton/compiler/compiler.py", line 488 in compile
  File "<string>", line 74 in static_persistent_tma_warp_specialized_matmul_kernel
  File "/home/jonch/Desktop/Programming/mlsys/triton/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py", line 438 in test_user_defined_persistent_warp_specialized_gemm
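Because this failure is a fatal abort rather than a Python exception, it takes down the whole pytest process. One way to probe many configurations on sm_89 is to run each compilation in a fresh interpreter and inspect its exit status. This is a minimal, POSIX-only sketch; `compiles_without_abort` is a hypothetical helper, and the snippet passed to it stands in for whatever kernel launch you want to test.

```python
import signal
import subprocess
import sys


def compiles_without_abort(snippet: str, timeout: float = 600.0) -> bool:
    """Run `snippet` in a child Python interpreter; return True only on a clean exit.

    A fatal abort (SIGABRT) kills only the child, so the caller can record the
    failure and move on to the next configuration.
    """
    proc = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == -signal.SIGABRT:
        # On POSIX, a negative return code means the child was killed by that signal.
        print(f"aborted:\n{proc.stderr}", file=sys.stderr)
        return False
    return proc.returncode == 0
```

For example, `compiles_without_abort("import os; os.abort()")` returns `False` instead of killing the caller, while `compiles_without_abort("pass")` returns `True`.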