Closed: xinyazhang closed this issue 3 months ago.
I tried to do some benchmarking with the attached test to check scaled dot product attention performance.
The benchmark runs the default, math, flash attention and memory efficient backends on the same problem and then prints out the results. It does that for all GPU devices it detects, one by one, and also for the CPU.
flash_attention_dot_product_benchmark_py.txt
It would be great if somebody could double-check that the benchmark does not have any problems. Suggestions for other tests to run would also be welcome. (This test takes only a little time to execute, so I wonder whether I should run some much heavier tests.)
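For anyone double checking, here is a minimal sketch (not the attached script) of the loop described above; the problem shape, dtypes and device handling are arbitrary assumptions of mine:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Run the same scaled_dot_product_attention problem on every detected GPU and
# on the CPU, once per backend; the shape and dtypes here are assumptions.
devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["cpu"]
backends = {
    "Math": SDPBackend.MATH,
    "Flash Attention": SDPBackend.FLASH_ATTENTION,
    "Memory Efficient Attention": SDPBackend.EFFICIENT_ATTENTION,
}
for device in devices:
    dtype = torch.float16 if device.startswith("cuda") else torch.float32
    q, k, v = (torch.randn(32, 32, 1024, 32, device=device, dtype=dtype) for _ in range(3))
    for name, backend in backends.items():
        try:
            with sdpa_kernel(backend):
                F.scaled_dot_product_attention(q, k, v)
            print(f"{device}: {name} ran")
        except RuntimeError as err:
            print(f"{device}: {name} not supported ({err})")
```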
I compared three different PyTorch/AOTriton builds I made with the rocm sdk builder for gfx1102 (AMD 7700S) and gfx1103:
1) ROCm 6.1.2 with PyTorch 2.3.1
2) ROCm 6.1.2 with PyTorch 2.4.1rc1 and this aotriton, but without tuning data for gfx1102 and gfx1103
3) ROCm 6.1.2 with PyTorch 2.4.1rc1 and this aotriton, compiled after copying the tuning data for gfx1102 and gfx1103
Results are great! See pictures/links to bar graphs below.
2) gfx1010, gfx1030, gfx1035, gfx1102, gfx1103 with tuned and non-tuned pt241rc1
3) more extensive resnet benchmark on different gpus
4) Text file results of tuned results for gfx1102 and gfx1103
Links to benchmarks used (updated version):
SQL tuning data copy commands: (need to learn how to do the real tuning for each GPU separately instead of using copy-data)
Same as the pictures
Excuse me. I just built PyTorch from source, with the latest https://github.com/ROCm/aotriton/commit/aae63d6e4b184c38442f0c3775665af5700a1e27 commit of AOTriton, and used it in my SD:Next installation. It seemed to work but with a degraded performance.
The first one was tested last month, either the math impl or a Triton impl which I mentioned here.
The second one is the AOTriton impl, which I can confirm by adding some code in flash_api.hip:
if (p_dropout == 114514) {
TORCH_CHECK(false, "flash_api.hip -> mha_fwd");
}
The third one is a random cross attention optimization.
EDIT: The second and third ones were benchmarked with --medvram, which is why they are slower than the first one. However, the fact that the AOTriton impl is slower than the math impl or other non-SDP optimizations doesn't change.
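(As a lighter-weight alternative to the flash_api.hip check above, one can restrict SDPA to the flash backend from Python and see whether the call succeeds; on ROCm builds that path is the AOTriton one. The shape and dtype below are arbitrary assumptions.)

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# If this runs, the flash (AOTriton on ROCm) backend handled the call;
# if it raises, PyTorch prints warnings explaining why it was rejected.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
try:
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        F.scaled_dot_product_attention(q, k, v)
    print("flash backend available for this shape/dtype")
except RuntimeError as err:
    print("flash backend rejected:", err)
```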
I am also getting some unexpected noise. Math vs AOT:
Is this an issue with my setup or WSL?
Here is how I build PyTorch:
$ # activate the venv for torch
$ git clone https://github.com/pytorch/pytorch
$ cd pytorch
$ git checkout v2.4.0
diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt
index d13e9d756c9..8b702706692 100644
--- a/.ci/docker/aotriton_version.txt
+++ b/.ci/docker/aotriton_version.txt
@@ -1,5 +1,5 @@
0.6b
manylinux_2_17
rocm6
-04b5df8c8123f90cba3ede7e971e6fbc6040d506
+aae63d6e4b184c38442f0c3775665af5700a1e27
3db6ecbc915893ff967abd6e1b43bd5f54949868873be60dc802086c3863e648
diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake
index ec6f09b6053..f6b46c97919 100644
--- a/cmake/External/aotriton.cmake
+++ b/cmake/External/aotriton.cmake
@@ -19,6 +19,7 @@ if(NOT __AOTRITON_INCLUDED)
BINARY_DIR ${__AOTRITON_BUILD_DIR}
PREFIX ${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
+ -DTARGET_GPUS=Navi31
-DAOTRITON_COMPRESS_KERNEL=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
#!/bin/sh
# specify navi31 for setup.py
export PYTORCH_ROCM_ARCH=gfx1100
export AMDGPU_TARGETS=gfx1100
# fix call to rocm_agent_enumerator in hipcc.pl
export HCC_AMDGPU_TARGET=gfx1100
export MAX_JOBS=8
echo 2.4.0 > version.txt
pip3 install cmake ninja
pip3 install -r requirements.txt
# hipify
python3 tools/amd_build/build_amd.py
# build wheel in the dist dir
python3 setup.py bdist_wheel
# install the wheel and then build torchvision v0.19.0
Finally, install the wheels of torch and torchvision in my SD:Next installation.
I tried the same thing yesterday but on 2.5.0. The SDPBackend.MATH kernel is down to 1.7 it/s from 2.9 it/s, and while SDPBackend.EFFICIENT_ATTENTION "worked" and produced valid images while saving VRAM, the miopen autotune took multiple minutes per batch/res and after all that only got up to 3.0 it/s compared to the CK FA's 3.8 it/s.
Extra cursed but if you want maximum throughput you switch between the CK FA and the possibly bugged AOT efficient attention with dim <128 to hit 4.0 it/s.
Naive patch for Navi on Torch master
diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake
index cb6080f5f7..e2ba77f61d 100644
--- a/cmake/External/aotriton.cmake
+++ b/cmake/External/aotriton.cmake
@@ -14,7 +14,7 @@ if(NOT __AOTRITON_INCLUDED)
list(GET __AOTRITON_CI_INFO 3 __AOTRITON_CI_COMMIT)
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
- GIT_TAG ${__AOTRITON_CI_COMMIT}
+ GIT_TAG main
SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
BINARY_DIR ${__AOTRITON_BUILD_DIR}
PREFIX ${__AOTRITON_INSTALL_DIR}
@@ -23,6 +23,7 @@ if(NOT __AOTRITON_INCLUDED)
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NO_SHARED=OFF
+ -DTARGET_GPUS=Navi31
# CONFIGURE_COMMAND ""
BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system.
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
So I modified @lamikr's benchmark script:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# copyright (C) Mika Laitio, lamikr@gmail.com
# scaled dot product attention benchmark based on the documentation at
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
device = 'cuda:0'
solver_name_arr=['Default', 'Math', 'Flash Attention', 'Memory Efficient Attention']
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt='f(*args, **kwargs)', globals={'args': args, 'kwargs': kwargs, 'f': f}
)
return t0.blocked_autorange().mean * 1e6
def benchmark_scaled_dot_product_attention_for_shape(BATCH, N_HEADS, N_CTX, D_HEAD):
print('')
print(f'======== BATCH={BATCH} N_HEADS={N_HEADS} N_CTX={N_CTX} D_HEAD={D_HEAD} ========')
query = torch.randn(BATCH, N_HEADS, N_CTX, D_HEAD, device=device, dtype=dtype)
key = torch.randn(BATCH, N_HEADS, N_CTX, D_HEAD, device=device, dtype=dtype)
value = torch.randn(BATCH, N_HEADS, N_CTX, D_HEAD, device=device, dtype=dtype)
print(f' {solver_name_arr[0]}:')
microseconds = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f' {microseconds:.3f} ms')
best_backend = (solver_name_arr[0], microseconds)
# Lets explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel
print(f' {solver_name_arr[1]}:')
with sdpa_kernel(SDPBackend.MATH):
microseconds = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f' {microseconds:.3f} ms')
best_backend = (solver_name_arr[1], microseconds) if microseconds < best_backend[1] else best_backend
print(f' {solver_name_arr[2]}:')
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
try:
microseconds = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f' {microseconds:.3f} ms')
best_backend = (solver_name_arr[2], microseconds) if microseconds < best_backend[1] else best_backend
except RuntimeError:
print(f' {solver_name_arr[2]} is not supported. See warnings for reasons.')
print(f' {solver_name_arr[3]}:')
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
try:
microseconds = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f' {microseconds:.3f} ms')
best_backend = (solver_name_arr[3], microseconds) if microseconds < best_backend[1] else best_backend
except RuntimeError:
print(f' {solver_name_arr[3]} is not supported. See warnings for reasons.')
print(f'======== BATCH={BATCH} N_HEADS={N_HEADS} N_CTX={N_CTX} D_HEAD={D_HEAD} ========')
print(f'======== BEST_BACKEND={best_backend[0]} RESULT={best_backend[1]:.3f} ms ========')
dtype = torch.float16
print('Benchmarking with Default, Math, Flash Attention and Memory Efficient Attention backends.')
print('PyTorch version:', torch.__version__)
print('ROCM HIP version', torch.version.hip)
# default shape from original benchmark
for BATCH in [32]:
for N_HEADS in [32]:
for N_CTX in [1024]:
for D_HEAD in [32]:
benchmark_scaled_dot_product_attention_for_shape(BATCH, N_HEADS, N_CTX, D_HEAD)
# typical shapes for stable diffusion are (2, 8, [77, 256, 1024, 4096, 9216], [40, 160])
for BATCH in [2]:
for N_HEADS in [8]:
for N_CTX in [77, 256, 1024, 4096, 9216]:
for D_HEAD in [40, 160]:
benchmark_scaled_dot_product_attention_for_shape(BATCH, N_HEADS, N_CTX, D_HEAD)
The outputs:
Benchmarking with Default, Math, Flash Attention and Memory Efficient Attention backends.
PyTorch version: 2.4.0+git34bc123
ROCM HIP version 6.1.40093-bd86f1708
======== BATCH=32 N_HEADS=32 N_CTX=1024 D_HEAD=32 ========
Default:
5408.553 ms
Math:
13893.611 ms
Flash Attention:
5429.846 ms
Memory Efficient Attention:
5439.184 ms
======== BATCH=32 N_HEADS=32 N_CTX=1024 D_HEAD=32 ========
======== BEST_BACKEND=Default RESULT=5408.553 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=77 D_HEAD=40 ========
Default:
25.621 ms
Math:
35.253 ms
Flash Attention:
25.354 ms
Memory Efficient Attention:
25.230 ms
======== BATCH=2 N_HEADS=8 N_CTX=77 D_HEAD=40 ========
======== BEST_BACKEND=Memory Efficient Attention RESULT=25.230 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=77 D_HEAD=160 ========
Default:
133.646 ms
Math:
34.364 ms
Flash Attention:
133.844 ms
Memory Efficient Attention:
133.974 ms
======== BATCH=2 N_HEADS=8 N_CTX=77 D_HEAD=160 ========
======== BEST_BACKEND=Math RESULT=34.364 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=256 D_HEAD=40 ========
Default:
47.344 ms
Math:
33.814 ms
Flash Attention:
47.385 ms
Memory Efficient Attention:
47.409 ms
======== BATCH=2 N_HEADS=8 N_CTX=256 D_HEAD=40 ========
======== BEST_BACKEND=Math RESULT=33.814 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=256 D_HEAD=160 ========
Default:
303.830 ms
Math:
41.744 ms
Flash Attention:
304.616 ms
Memory Efficient Attention:
303.509 ms
======== BATCH=2 N_HEADS=8 N_CTX=256 D_HEAD=160 ========
======== BEST_BACKEND=Math RESULT=41.744 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=1024 D_HEAD=40 ========
Default:
438.977 ms
Math:
264.109 ms
Flash Attention:
434.273 ms
Memory Efficient Attention:
439.044 ms
======== BATCH=2 N_HEADS=8 N_CTX=1024 D_HEAD=40 ========
======== BEST_BACKEND=Math RESULT=264.109 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=1024 D_HEAD=160 ========
Default:
2029.862 ms
Math:
352.168 ms
Flash Attention:
2023.372 ms
Memory Efficient Attention:
2029.077 ms
======== BATCH=2 N_HEADS=8 N_CTX=1024 D_HEAD=160 ========
======== BEST_BACKEND=Math RESULT=352.168 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=40 ========
Default:
5204.252 ms
Math:
3589.673 ms
Flash Attention:
5163.242 ms
Memory Efficient Attention:
5200.344 ms
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=40 ========
======== BEST_BACKEND=Math RESULT=3589.673 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=160 ========
Default:
24856.113 ms
Math:
5005.118 ms
Flash Attention:
24756.145 ms
Memory Efficient Attention:
24857.965 ms
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=160 ========
======== BEST_BACKEND=Math RESULT=5005.118 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=40 ========
Default:
23981.586 ms
Math:
19545.756 ms
Flash Attention:
23998.982 ms
Memory Efficient Attention:
24107.721 ms
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=40 ========
======== BEST_BACKEND=Math RESULT=19545.756 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=160 ========
Default:
120753.834 ms
Math:
23917.182 ms
Flash Attention:
119387.834 ms
Memory Efficient Attention:
122226.725 ms
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=160 ========
======== BEST_BACKEND=Math RESULT=23917.182 ms ========
UPDATE: I also inspected tuning_database.sqlite3, where there is no record for arch = 'gfx1100' AND (`inputs$max_seqlen_q` > 1024 OR `inputs$max_seqlen_k` > 1024), which might explain the issue.
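For reference, a sketch of that inspection; only the column names quoted above (arch, `inputs$max_seqlen_q`) are assumed, while the table names are discovered from the database itself rather than hard-coded:

```python
import sqlite3

# Count gfx1100 tuning rows with max_seqlen_q above 1024 in every table that
# actually has those columns (table names vary, so they are not assumed here).
con = sqlite3.connect("tuning_database.sqlite3")
for (table,) in con.execute("SELECT name FROM sqlite_master WHERE type='table'"):
    cols = {row[1] for row in con.execute(f'PRAGMA table_info("{table}")')}
    if {"arch", "inputs$max_seqlen_q"} <= cols:
        n = con.execute(
            f'SELECT COUNT(*) FROM "{table}" WHERE arch = ? AND "inputs$max_seqlen_q" > 1024',
            ("gfx1100",),
        ).fetchone()[0]
        print(f"{table}: {n} rows for gfx1100 with max_seqlen_q > 1024")
con.close()
```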
I am also getting some unexpected noise. Math vs AOT:
This is a known problem and will be fixed by https://github.com/pytorch/pytorch/pull/133331
I have been experimenting with AOTriton for a week and successfully got the tuning system working, as described in the documentation. However, when I finally built AOTriton tuned for larger seqlen, it produced only noise.
Here is what I have done:
- Building with -DAOTRITON_BUILD_FOR_TUNING=ON compiles every possible kernel, resulting in over 80k variants for Navi31 alone. Additionally, building for torch.float32 takes significantly more time and may easily time out without -DAOTRITON_GPU_BUILD_TIMEOUT=0.0.
- AOTRITON_COMPRESS_KERNEL should always be set to ON. Otherwise, you may encounter a "relocation R_X86_64_PC32 out of range" error when linking the AOTriton library, indicating that the embedded kernel code is too large. If needed, run sudo apt install -y zstd libzstd-dev on Ubuntu.
- I disabled the building of torch.float32, bwd*, and debug* kernels, as my focus was solely on improving forward performance. The build process took around 6 hours, followed by another 6 hours for tuning.
I used these commands to tune for larger seqlen:
PYTHONPATH=cpptune_build/bindings/ python3 test/tune_flash.py --bias_type 0 --seqlen_q 2048 4096 8192 --json_file tuning_database.json
PYTHONPATH=cpptune_build/bindings/ python3 test/tune_flash.py --bias_type 0 --seqlen_k 2048 4096 8192 --json_file tuning_database.json
These commands generate a tuning_database.json file with the tuning results. Afterward, you can run this command to update the tuning_database.sqlite3:
cat tuning_database.json | python3 v2python/table_tool.py -k FLASH -f v2python/rules/tuning_database.sqlite3 --action rawjson
This will upsert the records, after which you can push the tuning_database.sqlite3 and use it for building in PyTorch.
However, these benchmark outputs indicate that the newly built kernels are broken:
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=40 dtype=torch.float16 ========
Default:
3793.005 ms
Math:
3572.682 ms
Flash Attention:
3865.911 ms
Memory Efficient Attention:
3861.553 ms
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=40 ========
======== BEST_BACKEND=Math RESULT=3572.682 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=160 dtype=torch.float16 ========
Default:
8.849 ms
Math:
4791.290 ms
Flash Attention:
9.147 ms
Memory Efficient Attention:
8.294 ms
======== BATCH=2 N_HEADS=8 N_CTX=4096 D_HEAD=160 ========
======== BEST_BACKEND=Memory Efficient Attention RESULT=8.294 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=40 dtype=torch.float16 ========
Default:
9.159 ms
Math:
19549.479 ms
Flash Attention:
8.850 ms
Memory Efficient Attention:
8.565 ms
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=40 ========
======== BEST_BACKEND=Memory Efficient Attention RESULT=8.565 ms ========
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=160 dtype=torch.float16 ========
Default:
8.907 ms
Math:
23881.084 ms
Flash Attention:
9.188 ms
Memory Efficient Attention:
8.642 ms
======== BATCH=2 N_HEADS=8 N_CTX=9216 D_HEAD=160 ========
======== BEST_BACKEND=Memory Efficient Attention RESULT=8.642 ms ========
Regarding the noise, have you tried disabling the SDPBackend.EFFICIENT_ATTENTION kernel? With that enabled the PixArt Sigma model gives noise even with https://github.com/pytorch/pytorch/pull/133331 applied. ccmake build shows USE_ROCM: ON so the switch should be working?
Currently running torch 2.4.1-rc1 with a cleared ~/.config/miopen cache and the following patches
Using
with torch.nn.attention.sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
main(...)
Pixart-Sigma successfully ran and produced a stable image with the expected memory savings. So far I haven't found other networks that need EFFICIENT_ATTENTION disabled, not even the gargantuan FLUX.
One more thing I noticed is 2.4.1-rc1 doesn't have the severe performance regressions 2.5.0-main has. The FLASH_ATTENTION backend is on-par with MATH for SDXL, but still severely behind in DiT models like pixart/flux. Additionally, even on SDXL it's much slower than both the incomplete CK flash from https://github.com/ROCm/flash-attention@howiejay/navi_support and the experimental triton flash from https://github.com/ROCm/flash-attention@main_perf using a simple function branch for head_dim <= 128.
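For context, a sketch of what such a head_dim branch can look like; the bshd transpose follows the flash_attn usage shown later in this thread, while the 128 cutoff, the dtype guard and the torch SDPA fallback are assumptions rather than the exact code used here:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # howiejay/navi_support or main_perf build
except ImportError:
    flash_attn_func = None

def attention(q, k, v, causal=False):
    # q, k, v arrive in torch's (batch, heads, seqlen, head_dim) layout.
    if flash_attn_func is not None and q.shape[-1] <= 128 and q.dtype in (torch.float16, torch.bfloat16):
        # flash_attn expects (batch, seqlen, heads, head_dim), hence the transposes.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```

A wrapper along these lines can then be monkey-patched over F.scaled_dot_product_attention, which is roughly what the SDPA monkey patches mentioned later in the thread do.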
@Beinsezii
Only this patch is applied to my PyTorch v2.4.0:
And some changes to cmake/External/aotriton.cmake to use my repo of AOTriton.
I just recompiled several times and confirmed that my new tuning records are causing the noise issue in Stable Diffusion 1.5, with only SDPBackend.FLASH_ATTENTION and SDPBackend.MATH enabled:
The image ends up with pure noise if any N_CTX > 4096 in the computation.
Test script:
import torch
for N_CTX in [1024, 4096, 4097]:
query = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
key = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
value = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
r1 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
r2 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)
atol = 0.0125
rtol = 0
print(f'N_CTX={N_CTX} allclose={torch.allclose(r1, r2, atol=atol, rtol=rtol)}')
print(f'r1={r1[0][0][0][:4]}')
print(f'r2={r2[0][0][0][:4]}')
print('')
Output:
N_CTX=1024 allclose=True
r1=tensor([-0.0812, -0.0159, -0.0928, -0.0147], device='cuda:0',
dtype=torch.float16)
r2=tensor([-0.0812, -0.0159, -0.0929, -0.0147], device='cuda:0',
dtype=torch.float16)
N_CTX=4096 allclose=True
r1=tensor([-0.0171, -0.0166, -0.0460, 0.0142], device='cuda:0',
dtype=torch.float16)
r2=tensor([-0.0171, -0.0166, -0.0460, 0.0142], device='cuda:0',
dtype=torch.float16)
N_CTX=4097 allclose=False
r1=tensor([ 0.0519, 0.0227, 0.0270, -0.0340], device='cuda:0',
dtype=torch.float16)
r2=tensor([1.5318e-05, 1.5318e-05, 1.5318e-05, 1.5318e-05], device='cuda:0',
dtype=torch.float16)
This is my tuning_database.json if you are interested:
I just recompiled several times and confirmed that my new tuning records are causing the noise issue in Stable Diffusion 1.5, with only SDPBackend.FLASH_ATTENTION and SDPBackend.MATH enabled
Yes, I just ran SD15 up to 2048x2048, which peaks at seqlen 36864 on the outer layers, and both FLASH and EFFICIENT attentions were completely stable. PixArt is still the only model I managed to break with EFFICIENT attention enabled on the stock tunes. One of my kernels timed out during compilation; I wonder if that's related?
Interestingly enabling either FLASH or EFFICIENT drops my SD15 speed by well over 30%, similar to the DiT models. Maybe it's more that SDXL is unusually fast with the new kernels?
| attention | speed @ 512 |
|---|---|
| Math | 19.1 it/s |
| Efficient or Flash | 13.9 it/s |
| Math + howiejay/navi_support | 23.8 it/s |
| All SDPA + howiejay/navi_support | 21.6 it/s |
Losses apply to all resolutions, though the all SDPA + howiejay combo might be faster for big upscale jobs from not needing vae tiling.
Update: The philox branch was merged. Do I dare rebuild..?
The philox branch was merged. Do I dare rebuild..?
No idea if these changes optimize performance for Navi 31.
In fact, I am going to do some Frankenstein things by combining these to create a pip package:
No idea if these changes optimize performance for Navi 31.
I'll let it build while I'm at the store tomorrow. I'll increase the kernel build timeout too since in theory the ones torch pulls are already validated if I'm understanding aotriton's MO correctly.
In fact, I am going to do some Frankenstein things by combining these to create a pip package:
A lot of the autotunes in aotriton/tritonsrc can page fault or even reset your gpu. I didn't find it worth tinkering with personally. If you do, build triton from master to hopefully have new compiler fixes.
main_perf is interesting. It does pretty well for diffusion and almost works for llama but sometimes it likes to barf out bad tokens. I don't think it's intended for end-use as it's got lots of debug printouts lol. It's actually a lot faster than navi's torch sdpa as well. I think it's the closest thing to a truly cross-platform flash attention that exists right now. I haven't tried backwards yet but since it's triton it should work?
@Beinsezii
Previously I made AOTriton's tritonsrc work in SD:Next by replacing the existing Navi 31 hack with a custom adapter to employ Triton's auto-tuning. The result is a bit slower than the Math impl, but AOTriton's impl is more complete than others written in Triton.
ROCm/flash-attention@main_perf can provide a widely used interface for kernels written in Triton; that is what I am interested in.
Regarding the performance, I think it's more a matter of Triton compiler work, but if ROCm/flash-attention@main_perf performs better, AOTriton's might be able to achieve that too.
@Beinsezii
May I ask how you use ROCm/flash-attention@main_perf?
It looks blazing fast at first glance, but I notice there is a hardcoded layout (input_metadata.layout = "bshd") and I have to transpose(1, 2) before replacing SDPA with it; then the performance becomes ordinary.
@evshiron
I believe the transpose is negligible. At one point I made a custom Diffusers attention processor and it didn't seem any faster than just monkey patching sdpa so I removed it.
If you check out my quickdif repo at 3f832df6fccb7488ad7ed203d1dcadd820548965, before I removed the attention processors, there's a file containing a diffusers attention processor for Flash Attention that works with both the howiejay and triton branches. Quickdif has the sdpa monkey patch on a flag so you can easily compare speeds.
Back when I first found main_perf I built triton from source and it was something like 3.4 it/s at SDXL 1024. Whether it's changes to upstream triton or the kvpacked branch being merged to main_perf it does seem slower now.
I also tried using main_perf on llama at some point but it must have an auto tune for sequence length which causes a kernel fetch for every new token effectively making it unusable.
I ran a benchmark script:
#!/usr/bin/env python
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
import torch
import triton
TEST_TRITON = False
TEST_FLASH = True
TEST_FLASH_TRITON = True
TEST_TORCH = True
TEST_TORCH_MATH = True
# BATCH, N_HEADS, D_HEAD = 4, 32, 64
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX, D_HEAD = 2, 8, 4096, 128
# BATCH, N_HEADS, N_CTX, D_HEAD = 32, 32, 1024, 32
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
# for causal in [False, True]:
for causal in [False]:
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
# lower to allow torch sdpa to pass the benchmark
# x_vals=[2**i for i in range(10, 15)] if not TEST_TORCH_MATH else [2**i for i in range(8, 15)],
x_vals=[
77, 256,
1024,
4096, 8192,
9216,
],
line_arg='provider',
line_vals=(['triton'] if TEST_TRITON else []) + (['flash'] if TEST_FLASH else []) + (['flash-triton'] if TEST_FLASH_TRITON else []) + (['torch'] if TEST_TORCH else []) + (['torch-math'] if TEST_TORCH_MATH else []),
line_names=(['Triton'] if TEST_TRITON else []) + ([f'Flash'] if TEST_FLASH else []) + ([f'Flash Triton'] if TEST_FLASH_TRITON else []) + (['Torch'] if TEST_TORCH else []) + (['Torch Math'] if TEST_TORCH_MATH else []),
styles=[('red', '-'), ('orange', '-'), ('yellow', '-'), ('green', '-'), ('blue', '-'), ('indigo', '-'), ('violet', '-')],
ylabel='flops',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal,
})
)
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
print(f"{N_CTX=}")
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
split_kernel = False
requires_grad=True if mode == 'bwd' else False
# Bwd pass only supports causal=True right now
if mode == 'bwd':
split_kernel = True if causal else split_kernel
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
b = None
sm_scale = 1.3
return_encoded_softmax = False
autotune = True
return_autotune = True
fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune, return_autotune)[0]
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "flash":
from flash_attn import flash_attn_func
# transpose is needed because metadata.layout is set to bshd
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "flash-triton":
from flash_attn_rocm import flash_attn_func
# transpose is needed because metadata.layout is set to bshd
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "torch":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
b = None
sm_scale = 1.3
fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "torch-math":
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
b = None
sm_scale = 1.3
fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
if mode == 'bwd':
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path='.', print_data=True)
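A note for reading the tables below: the script returns total_flops / ms * 1e-9, which works out to TFLOP/s (higher is better), counting the two matmuls of the attention forward pass. A quick check of the arithmetic, with an assumed example runtime:

```python
BATCH, N_HEADS, N_CTX, D_HEAD = 2, 8, 4096, 128
flops_per_matmul = 2.0 * BATCH * N_HEADS * N_CTX * N_CTX * D_HEAD  # one of QK^T or PV
total_flops = 2 * flops_per_matmul                                 # forward pass does both
ms = 10.0                                                          # assumed runtime in milliseconds
print(total_flops / ms * 1e-9)                                     # ~13.74, i.e. ~13.7 TFLOP/s
```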
RX 7900 XTX (WSL, Ubuntu 22.04, ROCm 6.1.3):
fused-attention-batch2-head8-d128-fwd-causal=False:
N_CTX Flash Flash Triton Torch Torch Math
0 77.0 0.589959 0.760981 0.464005 0.488695
1 256.0 5.817600 3.338110 4.426412 2.501373
2 1024.0 23.633466 4.769343 8.805251 6.957538
3 4096.0 31.725495 5.340588 10.019228 9.271866
4 8192.0 33.280974 5.433053 10.138188 9.397304
5 9216.0 29.864885 5.443405 10.197588 9.261768
MATH only

fused-attention-batch2-head8-d128-fwd-causal=False:
   N_CTX      Flash      Torch  Torch Math
0   77.0   3.351398   3.307981    2.140681
1  256.0  17.674617  20.356401   14.038236
2 1024.0  55.155258  53.944516   29.788766
3 4096.0  66.690431  65.878503   31.758032
4 8192.0  70.919531  71.069738   29.774254
5 9216.0  70.319972  69.594797   29.259435
fused-attention-batch2-head8-d128-fwd-causal=False:
   N_CTX       Flash       Torch  Torch Math
0   77.0    4.211498    4.174075    2.549500
1  256.0   23.667371   24.601858   17.446561
2 1024.0   73.363291   72.073351   44.259957
3 4096.0  129.738712  127.396387   53.866077
4 8192.0  146.490523  143.630641   54.301182
5 9216.0  136.064218  133.215636   53.720898
I ran a benchmark script: {snip} RX 7900 XTX (WSL, Ubuntu 22.04, ROCm 6.1.3):
fused-attention-batch2-head8-d128-fwd-causal=False:
   N_CTX      Flash  Flash Triton      Torch  Torch Math
0   77.0   0.589959      0.760981   0.464005    0.488695
1  256.0   5.817600      3.338110   4.426412    2.501373
2 1024.0  23.633466      4.769343   8.805251    6.957538
3 4096.0  31.725495      5.340588  10.019228    9.271866
4 8192.0  33.280974      5.433053  10.138188    9.397304
5 9216.0  29.864885      5.443405  10.197588    9.261768
Is that using a 2.5 nightly wheel? Those have severe performance regressions for me. On my patched 2.4 I got
fused-attention-batch2-head8-d128-fwd-causal=False:
   N_CTX      Flash     Torch  Torch Math
0   77.0   1.048788  0.582226    1.585659
1  256.0   8.206998  4.299669   11.105464
2 1024.0  22.326422  6.281815   23.962240
3 4096.0  28.528275  8.420959   30.622303
4 8192.0  30.290357  8.520634   36.732278
5 9216.0  28.182249  8.543251   33.984900
Notice in particular the MATH backend.
@Beinsezii
Yes. The branch to merge has already updated version.txt to 2.5.0a0. But it's strange that it performs better than the CK-based one in your case, isn't it?
Yes. The branch to merge has already updated version.txt to 2.5.0a0. But it's strange that it performs better than the CK-based one in your case, isn't it?
Depending on the model that's actually true. I think it was PixArt or something that actually performs slightly better with math than with the ck flash ignoring the memory usage. SDXL is still 30-50% faster with CK flash though.
I'm pretty sure the navi flash was made with SDXL specifically in mind because that's what AMD's ck tune benchmarks were using at the time.
@Beinsezii
Have you tried https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal?
The performance drops to the level of the CK-based one after transposing. However, it might be the only one that implements a backward pass while maintaining good performance, I guess?
Have you tried https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal?
That one sketched me out because it's the first time I've ever gotten an extreme content warning on GitHub before accessing a repository, on GPU kernels of all things...
Right now I just use the howiejay CK flash with the aotriton 07 as a fallback for Diffusers. On the occasion I want to monkey with an LLM I just use llama.cpp's built-in FA which also supports ROCm. In theory the aotriton 07 kernels should work for exl2 but if the DiT performance is anything to go by it won't be fast by any measure.
Shame they don't have discussions open anywhere so this info can be more accessible.
@evshiron I bet most of the massive torch 2.5 performance regression is from https://github.com/pytorch/pytorch/pull/128922 So in theory it shouldn't affect the AOTriton kernels.
Hello @evshiron, I am a novice in all of this and have seen that you seem to understand this a whole lot better than I do. I mainly use stable diffusion with my RX 7900 XTX and am wondering if you could tell me which version of flash attention would be best for VRAM usage. I know that there is a CK, an AOTriton, and a WMMA version of flash attention, but I don't understand the graphs, so I don't know which would be best for my case. Could you please help me out? That would be very kind.
@Beinsezii
I bet most of the massive torch 2.5 performance regression is from https://github.com/pytorch/pytorch/pull/128922
Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.
@Kademo15
I use mainly stable diffusion with my rx7900xtx
Generating images with Stable Diffusion doesn't require a backward pass implementation, so I will recommend CK-based ones:
pip install -U git+https://github.com/ROCm/flash-attention@navi_support
I bet most of the massive torch 2.5 performance regression is from pytorch/pytorch#128922
Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.
I think it does still affect the MIOpen tuning that happens automatically. First time using a resolution/batch size on torch 2.5/roc62 takes probably 3x as long compared to 2.4/roc61, even if the flash kernels are available. For large models or dimensions it's easily multiple minutes of overhead for me. You can see it yourself by clearing ~/.config/miopen/ or whatever the WSL equivalent is.
If it's not that PR I'm not sure what else it'd be. Nightly torch is a mess for ROCm right now.
...so I will recommend...
Wonder if it'd be worth consolidating the discussion of everything relating to the different AOTriton, CK, and upstream Triton flash implementations somewhere with open threads, like https://github.com/ROCm/ROCm/discussions/
I did a brief overview of CK flash a while ago at https://github.com/huggingface/diffusers/discussions/7172
Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.
@evshiron It's made it into 2.5.0-rc1. Maybe one of us should open an issue, because won't that affect literally everyone with hardware f16 that's not NVIDIA?
@Beinsezii
I am currently using this branch:
Which is PyTorch 2.4.0 and has AOTriton updated. And I locally made a Flash Attention with different (currently two) backends.
For the same configuration we have tested, the results are:
fused-attention-batch2-head8-d128-fwd-causal=False:
N_CTX Triton Flash Attention Repeerc's Flash Attention Torch - (AOTriton?) Torch - Math
0 77.0 0.075390 0.696945 0.519582 1.089998
1 256.0 0.929930 11.258369 4.649597 9.788326
2 1024.0 7.258171 26.519371 8.881320 26.203581
3 4096.0 12.973244 33.084101 10.138325 34.008212
4 8192.0 13.660235 32.934404 10.160260 38.635985
5 9216.0 13.553775 33.048324 10.265818 33.515431
And the numbers for ROCm/flash-attention@howiejay/navi_support, taken from previous comment:
fused-attention-batch2-head8-d128-fwd-causal=False:
N_CTX Flash
0 77.0 0.589959
1 256.0 5.817600
2 1024.0 23.633466
3 4096.0 31.725495
4 8192.0 33.280974
5 9216.0 29.864885
@evshiron I wonder how they scale across different architectures? None of the implementations are super optimized so they swing wildly. Like for SDXL howiejay is the fastest by +20%, for SD15 the triton attention absolutely bombed by like -70%, for one of the DiT's (pixart or hunyuan?) I think Math was still the fastest as long as it didn't OOM. Eventually I just gave up and added CLI parameters to my app for runtime setting SDPA backend and adding flash attention monkey patches...
I really don't want to build torch from source a 12th time so I'll wait for Meta to figure out what they're doing with the 2.5 sdpa math casting before doing significant monkeying myself.
I don't know if you've seen, but the main_perf branch was cut into an upstream PR too. Looks like they cut backward again because it was slow, but it seems like with further development Navi will be able to depend on flash-attn somewhat normally.
@Beinsezii
The performance varies between FA implementations. If you are seeking the best performance, routing to different backends based on shapes and parameters might be a good approach, which is planned for my Flash Attention library, but I don't know if it's worth it.
The biggest problem for Triton (including AOTriton) is its performance for AMD GPUs. Currently the performance of Triton matmul is about 70% of hipBLAS on RX 7900 XTX[^1], and Flash Attention performance is even worse[^2]. For CDNA GPUs like MI250/MI300, the performance of AOTriton is about 70% of the CK one too[^3].
A year and a half after purchasing the RX 7900 XTX, I begin to wonder whether the performance improvement of Flash Attention on RX 7900 XTX could be as significant as it was on RTX 3090[^4] (or MI250 lol). Regardless, the VRAM usage does go down.
[^1]: You can run Triton's matmul tutorial locally.
[^2]: As we have tested for both Triton and AOTriton ones.
[^4]: At the end of https://github.com/ROCm/aotriton/pull/39#issuecomment-2330669978
A year and a half after purchasing the RX 7900 XTX, I begin to wonder whether the performance improvement of Flash Attention on RX 7900 XTX could be as significant as it was on RTX 3090
When llama 3 came out one of my friends borked their venv and didn't have Flash Attention. On Exllama2 which compiles torch c extensions using the wheel bundled cuda/rocm compilers, my XTX actually outperformed his 3090. The GPUs themselves are completely capable it's just the software holding them back. It's a Sisyphean situation where all the cutting edge projects are built optimized to CUDA specifically so ROCm is perpetually in catch-up mode. I think that's why they're taking the approach of using Triton everywhere they can to reduce the amount of places they need to optimize kernels.
That said, yes it seems upstream Triton is particularly slow right now. I remember having a Triton flash running over 90% as fast as the CK branch in SDXL, but I can't appear to reproduce that right now. Maybe I was using the ROCm fork? It was definitely unstable though, I had to be really picky with the autotune configs to have it not cause GPU resets.
Edit: Navi configs might be worth looking at https://github.com/ROCm/triton/pull/640
@Beinsezii
The GPUs themselves are completely capable it's just the software holding them back.
Yeah. I believe the raw performance is comparable to RTX 3090 and RTX 4080, and we do have projects like MLC that achieve higher CPR than RTX 4090. In my opinion, LLM applications like llama.cpp (Ollama) are already satisfying on RX 7900 XTX. Diffusion models are more VRAM hungry, and having a FA implementation with good performance would really help.
Maybe I was using the ROCm fork? It was definitely unstable though, I had to be really picky with the autotune configs to have it not cause GPU resets.
I am using the official repo of Triton and I believe most developments for AMD happen there now. I haven't experienced any auto-tuning crashes recently. Maybe it's a difference on WSL? Though the performance remains ordinary across different tuning configurations (including the mentioned ones).
@evshiron
Yeah. I believe the raw performance is at the same level as RTX 3090 and RTX 4080
The XTX should by every metric be faster than a 3090. I think the vector/ai/whatever flops was a bit over 3090 ti.
I am using the official repo of Triton and I believe most developments for AMD happen there now.
Out of curiosity I just built git+https://github.com/ROCm/triton.git@micmelesse/cache_fix#subdirectory=python and gained +6% for SDXL over upstream triton-lang using the flash branch from https://github.com/Dao-AILab/flash-attention/pull/1203. It ain't much but it's honest work. I think that puts it a bit above aotriton sdpa now?
I haven't experienced any auto-tuning crashes recently. Maybe it's a difference on WSL? Though the performance remains ordinary across different tuning configurations (including the mentioned ones).
I haven't had any resets recently either. This was quite a few months ago. Though the original triton.ops.flash_attention.flash_attn_func still causes a page fault.
@Beinsezii
I think that puts it a bit above aotriton sdpa now?
It's hard to tell.
AOTriton is using an old & custom version of Triton for kernel generation, which might include some specific optimizations, while the official repo of Triton is updated frequently, and bumping the LLVM version might also affect the performance.
As you can see from previous benchmarks, few of those implementations are able to consistently win when the shapes of the inputs change.
AOTriton is using an old & custom version of Triton for kernel generation, which might include some specific optimizations, while the official repo of Triton is updated frequently, and bumping the LLVM version might also affect the performance.
To clarify, the customization of Triton is mainly about removing CUDA bits, which would otherwise only increase the download/build time.
What's Changed
- New tuning support (cpptune / cpp_tune / cpptuning) based on pre-compiling all GPU kernels with CMake option AOTRITON_BUILD_FOR_TUNING and kernel selection parameters provided by all AOTriton API
- Use pkg-config to search for zstd since find_package(zstd) is not supported officially.

Known problems
This fixes #16