ROCm / aotriton

Ahead of Time (AOT) Triton Math Library
MIT License

[Issue]: failed to run the tune_flash.py #32

Open jinsong-mao opened 3 months ago

jinsong-mao commented 3 months ago

Problem Description

Hi,
I can't run tritonsrc/tune_flash.py to autotune the flash attention kernel for one specific problem size (just an example). The error message is: [screenshot]

We don't know which sequence lengths, numbers of heads, and head_dim values the pre-optimized database contains, so we may need to tune for our own sequence lengths with this script. However, we can't run it because of the issue above. Do I understand correctly? Can we fine-tune the kernels for different seqlen and heads?

Thanks

Operating System

Ubuntu 22

CPU

5900X

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.1.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

xinyazhang commented 3 months ago

This seems to be caused by API changes in upstream Triton. We recommend using the bundled Triton for tuning, since it is the actual compiler that generates the GPU code.

We are moving to upstream Triton, but this will not happen immediately because of known bugs.

jinsong-mao commented 3 months ago

Thanks @xinyazhang,

Is there an obvious performance gain from autotuning in our repo?

It is recommended to use the bundled Triton to do the tuning since this is the actual compiler that generates the GPU code.

My understanding is that we need to set up a tuning space like this https://github.com/michael-sandoval/aotriton/blob/b164a966c7eceac84cfda2c3719cbc3e5bcaa553/test/triton_attn_torch_function.py#L20, and then rebuild the project so that the build process tunes it automatically? Please correct me if I am wrong.

If I want to tune seqlen_q and head_dim, how should I do that?

Thanks

xinyazhang commented 3 months ago

My understanding is that we need to set up some tune space

Correct. The right file for editing the tuning space is tritonsrc/attn_torch_function.py (the file you pointed to is a legacy file used for debugging).
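As a rough illustration of what a "tune space" is, here is a minimal C++ sketch that enumerates every combination of a few per-kernel performance knobs, which a tuner would then benchmark one by one. The knob names (BLOCK_M, BLOCK_N, pre_load_v, waves_per_eu) come from the generated dispatcher posted below, but the `TuneConfig` struct, the `enumerate_tune_space` function, and any candidate values are hypothetical; the real space is defined in Python in tritonsrc/attn_torch_function.py and may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>

// One candidate configuration for the tuner to benchmark (hypothetical).
struct TuneConfig {
    int32_t block_m;
    int32_t block_n;
    bool pre_load_v;
    int32_t waves_per_eu;
};

// Builds the Cartesian product of the candidate values. A tuner would time
// each config for a given problem size and keep the fastest one.
std::vector<TuneConfig> enumerate_tune_space(const std::vector<int32_t>& block_ms,
                                             const std::vector<int32_t>& block_ns,
                                             const std::vector<int32_t>& waves) {
    std::vector<TuneConfig> space;
    for (int32_t m : block_ms)
        for (int32_t n : block_ns)
            for (bool pre_load : {false, true})
                for (int32_t w : waves)
                    space.push_back({m, n, pre_load, w});
    // Sanity check: the space grows multiplicatively with every knob.
    assert(space.size() == block_ms.size() * block_ns.size() * 2 * waves.size());
    return space;
}
```

For example, `enumerate_tune_space({64, 128}, {32, 64}, {0, 1, 2})` yields 2 * 2 * 2 * 3 = 24 candidates, which is why adding one more knob value can noticeably lengthen a tuning run.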

and then rebuild the project so that the build process will tune it automatically?

Currently, that is not the case. The tuning database v2python/rules/tuning_database.sqlite3 guides the compilation of the tuned GPU kernels, as well as the generation of the autotuning dispatcher during the build. tritonsrc/tune_flash.py has to be run to update this database in order to get an updated build (pass -h to tritonsrc/tune_flash.py for its usage). The tuning cannot be done during the build because of cross-compiling and environmental noise.
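To make the database-to-dispatcher relationship more concrete, the following C++ sketch shows one plausible way a generator could turn "best kernel variant per (seqlen_q bin, seqlen_k bin)" records into the two tables the dispatcher below relies on: a deduplicated list of kernel variants (like `image_list`) and a LUT of indices into it (like `lut`). This is an assumption about the generator's logic, not AOTriton's actual code; the `Best` record, `DispatchTables`, and `build_dispatch_tables` are hypothetical, and plain strings stand in for full kernel signatures.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical record: the fastest kernel variant found for one bin pair.
struct Best {
    std::size_t q_bin;
    std::size_t k_bin;
    std::string variant;  // stands in for a full kernel signature
};

struct DispatchTables {
    std::vector<std::string> image_list;        // unique variants, first-seen order
    std::vector<std::vector<std::size_t>> lut;  // lut[q_bin][k_bin] -> image_list index
};

DispatchTables build_dispatch_tables(const std::vector<Best>& rows,
                                     std::size_t num_q_bins, std::size_t num_k_bins) {
    DispatchTables t;
    t.lut.assign(num_q_bins, std::vector<std::size_t>(num_k_bins, 0));
    std::vector<std::vector<bool>> filled(num_q_bins, std::vector<bool>(num_k_bins, false));
    std::map<std::string, std::size_t> index_of;
    for (const Best& r : rows) {
        if (r.q_bin >= num_q_bins || r.k_bin >= num_k_bins)
            throw std::out_of_range("bin index outside the LUT");
        // Reuse an existing image_list slot if this variant was already seen.
        auto [it, inserted] = index_of.try_emplace(r.variant, t.image_list.size());
        if (inserted) t.image_list.push_back(r.variant);
        t.lut[r.q_bin][r.k_bin] = it->second;
        filled[r.q_bin][r.k_bin] = true;
    }
    // A bin pair without a tuned entry would leave a hole in the LUT.
    for (std::size_t q = 0; q < num_q_bins; ++q)
        for (std::size_t k = 0; k < num_k_bins; ++k)
            if (!filled[q][k]) throw std::runtime_error("missing tuning entry for a bin pair");
    return t;
}
```

Deduplicating variants is why the posted dispatcher has only five kernel images for a 6x6 LUT: many bin pairs share the same winning configuration.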

Tuning against seqlen_q is possible, but head_dim is a constexpr that is intrinsic to a "family" of GPU kernels and cannot be tuned. However, the database already contains tuning information for seqlen_q and seqlen_k, and usually you don't need to run tritonsrc/tune_flash.py yourself, unless your input data layout is very different from the tuning script's assumptions (e.g., packed QKV).
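The difference between head_dim and seqlen_q can be illustrated with a C++ analogy: a compile-time template parameter (playing the role of a Triton constexpr such as BLOCK_DMODEL in the dispatcher's signature comment below) produces a separate kernel for each value, so it selects a kernel family, while a runtime value such as seqlen_q can only choose among kernels that were already compiled. The `attn_fwd_family` template, `dispatch_by_head_dim`, and the tile-size rule are hypothetical and only mimic this structure; they do not reproduce AOTriton's code.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Each instantiation is a distinct "kernel family": BLOCK_DMODEL is baked in
// at compile time, just as a Triton constexpr is baked into the GPU binary.
template <int BLOCK_DMODEL>
struct attn_fwd_family {
    static constexpr int block_dmodel = BLOCK_DMODEL;
    // Runtime tuning only chooses among variants inside this family,
    // e.g. a tile size picked from seqlen_q (illustrative rule only).
    static int pick_block_m(int32_t seqlen_q) { return seqlen_q <= 128 ? 64 : 128; }
};

// head_dim must map to a family compiled ahead of time; a value with no
// precompiled family cannot be "tuned" into existence at runtime.
int dispatch_by_head_dim(int head_dim, int32_t seqlen_q) {
    switch (head_dim) {
        case 64:  return attn_fwd_family<64>::pick_block_m(seqlen_q);
        case 128: return attn_fwd_family<128>::pick_block_m(seqlen_q);
        default:  throw std::invalid_argument("no precompiled kernel family for this head_dim");
    }
}
```

In this analogy, supporting a new head dimension means compiling (and tuning) a new family, whereas a new sequence length only changes which existing variant is picked.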

The attached code is the generated autotuning dispatcher; hopefully it gives you some idea of how it works internally:

// Copyright © 2023-2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

// clang-format off
#define INCBIN_PREFIX g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_
#define INCBIN_STYLE INCBIN_STYLE_SNAKE

#define mangle(x) g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_ ## x ## _data
#define smangle(x) g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_ ## x ## _size

#include "../shim.attn_fwd.h"
#include <aotriton/_internal/triton_kernel.h>
#include <incbin.h>
#include <iostream>

// ['Q', 'K', 'V', 'B', 'Out', 'encoded_softmax'] = *bf16:16 sm_scale = fp32 M = *fp32:16 ['stride_qz', 'stride_qh', 'stride_qm'] = u64:16 stride_qk = 1 ['stride_kz', 'stride_kh', 'stride_kn'] = u64:16 stride_kk = 1 ['stride_vz', 'stride_vh', 'stride_vk'] = u64:16 stride_vn = 1 ['stride_bz', 'stride_bh', 'stride_bm'] = u64:16 stride_bn = 1 ['stride_oz', 'stride_oh', 'stride_om'] = u64:16 stride_on = 1 ['seqlen_q', 'seqlen_k'] = i32 head_dim = u64 dropout_p = fp32 philox_seed = u64 philox_offset_base = u32 CAUSAL = False BLOCK_DMODEL = 128 ENABLE_DROPOUT = False RETURN_ENCODED_SOFTMAX = False PADDED_HEAD = False BIAS_TYPE = 0 ; BLOCK_M = 128 BLOCK_N = 64 pre_load_v = 1 ; num_warps=4 num_stages=1 waves_per_eu=1
#define CURRENT_ENTRY_PUBLIC Autotune_attn_fwd__A1__F208

INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0-Gpu-MI300X.hsaco.zst");

#ifndef NDEBUG
static const char* incbin_kernel_names[] = {
  "F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1",
  "F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2",
  "F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1",
  "F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0",
  "F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0"
};
#endif

namespace { // Anonymous namespace

struct PerfFields {
    int32_t BLOCK_M;
    int32_t BLOCK_N;
    bool pre_load_v;
};

PerfFields image_perf_list [] = {
    { .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 1 },
    { .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 },
    { .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 },
    { .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 1 },
    { .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 }
};

aotriton::TritonKernel image_list [] = {
    { mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1), { 256 , 1, 1 }, 34816 },
    { mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2), { 256 , 1, 1 }, 34816 },
    { mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1), { 256 , 1, 1 }, 34816 },
    { mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0), { 256 , 1, 1 }, 34816 },
    { mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0), { 256 , 1, 1 }, 34816 },
};

// lut[seqlen_q bin][seqlen_k bin] -> index into image_list / image_perf_list
uint8_t lut[6][6] = {{0,1,2,3,2,0},
 {1,2,2,4,0,0},
 {2,2,1,0,0,0},
 {2,0,0,3,0,3},
 {3,3,3,0,0,3},
 {0,0,0,3,3,3}};

} // End of anonymous namespace

namespace aotriton::v2::flash::autotune {

// using aotriton::v2::flash::AttnFwdParams;

void CURRENT_ENTRY_PUBLIC::operator()(AttnFwdParams& params) {
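    // Bin seqlen_q (and seqlen_k below) by the edges 64, 128, ..., 2048;
    // values above 2048 are clamped into the last bin (5).
    auto seqlen_q_binned_index = [] (int x) {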
        if (x <= 64) return 0;
        if (x <= 128) return 1;
        if (x <= 256) return 2;
        if (x <= 512) return 3;
        if (x <= 1024) return 4;
        if (x <= 2048) return 5;
        return 5;
    }(params.seqlen_q);
    auto seqlen_k_binned_index = [] (int x) {
        if (x <= 64) return 0;
        if (x <= 128) return 1;
        if (x <= 256) return 2;
        if (x <= 512) return 3;
        if (x <= 1024) return 4;
        if (x <= 2048) return 5;
        return 5;
    }(params.seqlen_k);
    auto kernel_index = lut[seqlen_q_binned_index][seqlen_k_binned_index];
    params.selected_kernel = &image_list[kernel_index];
#ifndef NDEBUG
    std::cerr << __FILE__ << " kernel_index = " << int(kernel_index) << std::endl;
    params._debug_kernel_name = incbin_kernel_names[kernel_index];
#endif
    const auto& perf = image_perf_list[kernel_index];
    params.BLOCK_M = perf.BLOCK_M;
    params.BLOCK_N = perf.BLOCK_N;
    params.pre_load_v = perf.pre_load_v;
}

#undef CURRENT_ENTRY_PUBLIC
#undef mangle
#undef smangle
} // namespace aotriton::v2::flash::autotune

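One consequence of the dispatcher above is that any sequence length larger than 2048 falls into the last bin, so with this particular generated file, seqlen values such as 4096 or 16384 always get the kernel chosen for the 2048 bin. The standalone C++ sketch below re-implements only the binning and LUT lookup from the listing so this can be checked on the host; `seqlen_bin`, `kLut`, and `select_kernel` are hypothetical names, while the bin edges and LUT values are copied from the posted code.

```cpp
#include <cassert>
#include <cstdint>

// Same bin edges as the generated dispatcher: 64, 128, ..., 2048.
// Anything above the largest edge is clamped into the last bin.
int seqlen_bin(int32_t x) {
    assert(x >= 0 && "sequence lengths are non-negative");
    constexpr int32_t edges[] = {64, 128, 256, 512, 1024, 2048};
    constexpr int kNumBins = sizeof(edges) / sizeof(edges[0]);
    for (int i = 0; i < kNumBins; ++i)
        if (x <= edges[i]) return i;
    return kNumBins - 1;
}

// LUT copied from the listing: kLut[q_bin][k_bin] -> index into image_list.
constexpr uint8_t kLut[6][6] = {{0,1,2,3,2,0},
                                {1,2,2,4,0,0},
                                {2,2,1,0,0,0},
                                {2,0,0,3,0,3},
                                {3,3,3,0,0,3},
                                {0,0,0,3,3,3}};

int select_kernel(int32_t seqlen_q, int32_t seqlen_k) {
    return kLut[seqlen_bin(seqlen_q)][seqlen_bin(seqlen_k)];
}
```

For example, `select_kernel(16384, 16384)` returns the same index as `select_kernel(2048, 2048)`, so longer sequences can only receive a different kernel if the generated bins themselves are extended.
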
jinsong-mao commented 3 months ago

Thanks @xinyazhang,

I was able to use tritonsrc/tune_flash.py to autotune the database for larger sequence lengths (4096 to 16384) with head_dim=128, the head dimension Llama-7B uses. We see some performance improvement for the forward kernel. However, after autotuning we can't build the project with "ninja install"; there are some strange errors related to the LUTs (maybe in the dispatcher). We are not sure whether our flow is correct, or whether we need to do anything else to get higher performance. It looks like the GPU kernel binaries in our experiment are the same as in the out-of-the-box (OOTB) build.

By the way, it looks like we need to use the scripts under tritonsrc to get better performance.

xinyazhang commented 3 months ago

We can find some performance improvement for forward kernel, However, we can't build this project after autotune with "ninja install",

An error message would be helpful. However, the most likely cause is that the database got truncated. The updated database should be roughly the same size as the original one.
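A quick way to act on this hint is to compare the sizes of the original and updated database files before rebuilding. The C++17 sketch below does that with std::filesystem; the function names `size_looks_truncated` and `db_looks_truncated` and the 0.9 ratio threshold are arbitrary choices for illustration, not values taken from AOTriton.

```cpp
#include <cassert>
#include <cstdint>
#include <filesystem>

// Returns true if `new_size` is below `min_ratio` of `orig_size`, which would
// suggest that rows were lost (e.g. the database file was truncated).
// A larger file is not flagged, since tuning normally only adds rows.
bool size_looks_truncated(std::uintmax_t orig_size, std::uintmax_t new_size,
                          double min_ratio = 0.9) {
    return static_cast<double>(new_size) < min_ratio * static_cast<double>(orig_size);
}

// Convenience wrapper that reads both file sizes from disk.
// Throws std::filesystem::filesystem_error if either file is missing.
bool db_looks_truncated(const std::filesystem::path& original,
                        const std::filesystem::path& updated,
                        double min_ratio = 0.9) {
    return size_looks_truncated(std::filesystem::file_size(original),
                                std::filesystem::file_size(updated), min_ratio);
}
```

A size check is only a coarse signal; it cannot tell whether the right rows are present, only that a large fraction did not disappear.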

jinsong-mao commented 3 months ago

Sure,

My autotune command is: python tritonsrc/tune_flash.py --db_file v2python/rules/tuning_database.sqlite3 --seqlen_q 1024 --dtype float16 --d_head 128

And the error message from ninja install is: [screenshot]

Thanks

xinyazhang commented 3 months ago

It seems the HSACO kernel was not correctly compiled. You probably want to remove ~/.triton/cache and /tmp/amd_triton_kernel-* and try again.
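For reference, this cleanup can also be scripted. In a shell, `rm -rf ~/.triton/cache /tmp/amd_triton_kernel-*` does it directly; the C++17 sketch below does the equivalent with std::filesystem, removing a cache directory and every entry in a temp directory whose name starts with a given prefix. The function name `clear_triton_caches` is hypothetical, and the caller must pass already-expanded paths (it does not expand `~` or glob patterns itself).

```cpp
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Deletes `cache_dir` recursively and every entry directly under `tmp_dir`
// whose filename begins with `prefix`. Returns the number of filesystem
// entries removed. Paths that do not exist are skipped.
std::uintmax_t clear_triton_caches(const fs::path& cache_dir,
                                   const fs::path& tmp_dir,
                                   const std::string& prefix) {
    std::uintmax_t removed = fs::remove_all(cache_dir);  // 0 if it does not exist
    if (!fs::is_directory(tmp_dir)) return removed;
    // Collect matches first, then delete, so the directory is not modified
    // while it is being iterated.
    std::vector<fs::path> victims;
    for (const auto& entry : fs::directory_iterator(tmp_dir)) {
        if (entry.path().filename().string().rfind(prefix, 0) == 0)
            victims.push_back(entry.path());
    }
    for (const auto& p : victims) removed += fs::remove_all(p);
    return removed;
}
```

A typical call would be `clear_triton_caches(home / ".triton" / "cache", "/tmp", "amd_triton_kernel-")`, where `home` is the expanded home directory.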

jinsong-mao commented 3 months ago

It seems the HSACO kernel was not correctly compiled. You probably want to remove ~/.triton/cache and /tmp/amd_triton_kernel-* and try again.

It looks like we get new errors after applying the above commands. After autotuning the database and removing the cache and /tmp/xxx, we get the same error whether we only rebuild or re-run CMake and then build: [screenshot]

Could you have a look at it? Thanks

xinyazhang commented 3 months ago

Could you have a look at it? Thanks

It seems your tuning database contains an entry with seqlen_q=16384, seqlen_k=2048 for only a subset of configurations. Can you post your tuning_database.sqlite3 so I can take a look at it?

jinsong-mao commented 3 months ago

Sure, tuning_database.zip

Actually, there is no need to tune seqlen_q=16384, seqlen_k=2048; in our cases seqlen_q and seqlen_k are always equal. It looks like we need to set them manually in tune_flash.py.

jinsong-mao commented 3 months ago

I have tuned the database with equal seqlen_q and seqlen_k, but I get a similar error: [screenshot]

The database is here: tuning_database.zip

BTW, after your varlen check-in, the latest repo can't run tune_flash.py; it fails with the following error: [screenshot]

Thanks

xinyazhang commented 3 months ago

@jinsong-mao I found the problem. For each (seqlen_q, seqlen_k) pair, the tuning database should have all configurations tuned. For example [screenshot], you can see there are 288 entries (144 per arch) tuned for (2048, 2048), but only 4 entries for the new tuning option.
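The invariant described here (every (seqlen_q, seqlen_k) pair must be tuned for every configuration) can be checked before rebuilding. The C++ sketch below counts the distinct tuned configurations per pair and reports the pairs that fall short of the expected count. It assumes the rows for one arch have already been read out of tuning_database.sqlite3 into memory (the SQLite query is omitted), and the `TunedRow` layout, `SeqPair` alias, and `find_incomplete_pairs` function are hypothetical rather than the database's real schema.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical in-memory form of one tuning-database row for a single arch.
struct TunedRow {
    int32_t seqlen_q;
    int32_t seqlen_k;
    int32_t config_id;  // stands in for the dtype/causal/head_dim/... combination
};

using SeqPair = std::pair<int32_t, int32_t>;

// Returns each (seqlen_q, seqlen_k) pair whose number of distinct tuned
// configurations is below `expected_per_pair`, together with the actual count.
std::vector<std::pair<SeqPair, std::size_t>>
find_incomplete_pairs(const std::vector<TunedRow>& rows, std::size_t expected_per_pair) {
    std::map<SeqPair, std::set<int32_t>> configs_per_pair;
    for (const auto& r : rows)
        configs_per_pair[{r.seqlen_q, r.seqlen_k}].insert(r.config_id);  // duplicates collapse
    std::vector<std::pair<SeqPair, std::size_t>> incomplete;
    for (const auto& [key, configs] : configs_per_pair)
        if (configs.size() < expected_per_pair)
            incomplete.emplace_back(key, configs.size());
    return incomplete;
}
```

With the numbers from this thread, running it on one arch's rows with `expected_per_pair = 144` would be expected to flag the newly tuned pair while leaving (2048, 2048) alone.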

It is possible to fix this, but our whole team has been reassigned to another urgent task, so the ETA is hard to estimate. The workaround that requires no code changes is to re-run tune_flash.py with the default options plus the desired seqlen_q and seqlen_k, and to compile only for the arch you modified.

jinsong-mao commented 3 months ago

@xinyazhang, I think it would be better to fix this out of the box (OOTB). The typical seqlen is larger than 2048 in most cases, and the default database seems to cover only head_dim=64, while even the smallest Llama-2 model (7B) uses head_dim=128. The official flash attention (version 2.5.9) now supports head_dim=256, which gives a performance gain for larger models.