ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Issue]: Installation of flash-attention failed #88

Closed Kingmeng-Stack closed 2 weeks ago

Kingmeng-Stack commented 4 weeks ago

Problem Description

When trying to build Flash Attention from source on the ROCm platform, the compilation fails with an invalid assembly instruction error. The specific error occurs in the bfloat16 implementation.

Environment

Error Details

The compilation fails with the following error:

In file included from .../fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip:6:
[...truncated...]
/csrc/composable_kernel/include/ck_tile/core/numeric/bfloat16.hpp:170:21: error: invalid operand for instruction
    170 |     asm volatile("\n \
        |                     ^
<inline asm>:2:26: note: instantiated into assembly here
      2 |              v_cmp_u_f32 s[4:5], v0, v0

I first tried to use the precompiled wheel, but it wasn't available:

Guessing wheel URL: https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+rocm62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
Precompiled wheel not found. Building from source...
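
For reference, the installer's behaviour here is to probe a release asset URL and fall back to a source build when the download fails. Below is a rough Python sketch of that guess-and-fallback logic; it is illustrative only, not the actual setup.py code, and the helper name is made up:

```python
import urllib.error
import urllib.request

# Hypothetical sketch of the guess-and-fallback behaviour seen in the log above.
# The real flash-attention setup.py derives the wheel name from the package
# version, ROCm/torch versions, C++ ABI flag, and Python tag; this is not its code.
WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/"
    "flash_attn-2.6.3+rocm62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)

def wheel_is_available(url: str) -> bool:
    """Return True if the release asset exists, False if the download 404s."""
    try:
        with urllib.request.urlopen(url) as resp:
            return resp.status == 200
    except urllib.error.HTTPError:
        return False

print("Guessing wheel URL:", WHEEL_URL)
if wheel_is_available(WHEEL_URL):
    print("Downloading precompiled wheel...")
else:
    # This is the path taken here: the rocm62/torch2.3 asset does not exist
    # on the upstream release page, so the build falls back to source.
    print("Precompiled wheel not found. Building from source...")
```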

The build from source then failed with the assembly instruction error mentioned above.

  building 'flash_attn_2_cuda' extension
  creating /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10
  creating /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/build
  creating /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/csrc
  creating /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/csrc/flash_attn_ck
  Emitting ninja build file /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/build.ninja...
  Compiling objects...
  Using envvar MAX_JOBS (23) as the number of workers...
  [1/2334] /opt/rocm-6.2.3/bin/hipcc  -I/home/meng/workspace/flash-attention/csrc/composable_kernel/include -I/home/meng/workspace/flash-attention/csrc/composable_kernel/library/include -I/home/meng/workspace/flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/TH -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/THC -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/THH -I/opt/rocm-6.2.3/include -I/home/meng/workspace/python/include -I/usr/include/python3.10 -c -c /home/meng/workspace/flash-attention/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip -o /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 --offload-arch=native -O3 -std=c++17 -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_USE_XDL -DUSE_PROF_API=1 -D__HIP_PLATFORM_HCC__=1 -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3 -fno-offload-uniform-block -mllvm -enable-post-misched=0 -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -mllvm -amdgpu-coerce-illegal-types=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -fno-gpu-rdc
  FAILED: /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.o
  /opt/rocm-6.2.3/bin/hipcc  -I/home/meng/workspace/flash-attention/csrc/composable_kernel/include -I/home/meng/workspace/flash-attention/csrc/composable_kernel/library/include -I/home/meng/workspace/flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/TH -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/THC -I/home/meng/workspace/python/lib/python3.10/site-packages/torch/include/THH -I/opt/rocm-6.2.3/include -I/home/meng/workspace/python/include -I/usr/include/python3.10 -c -c /home/meng/workspace/flash-attention/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip -o /home/meng/workspace/flash-attention/build/temp.linux-x86_64-3.10/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 --offload-arch=native -O3 -std=c++17 -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_USE_XDL -DUSE_PROF_API=1 -D__HIP_PLATFORM_HCC__=1 -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3 -fno-offload-uniform-block -mllvm -enable-post-misched=0 -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -mllvm -amdgpu-coerce-illegal-types=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -fno-gpu-rdc
  In file included from /home/meng/workspace/flash-attention/build/fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip:6:
  In file included from /home/meng/workspace/flash-attention/csrc/composable_kernel/example/ck_tile/01_fmha/fmha_bwd_hip.hpp:7:
  In file included from /home/meng/workspace/flash-attention/csrc/composable_kernel/include/ck_tile/core_hip.hpp:10:
  In file included from /home/meng/workspace/flash-attention/csrc/composable_kernel/include/ck_tile/core/arch/amd_buffer_addressing_hip.hpp:9:
  In file included from /home/meng/workspace/flash-attention/csrc/composable_kernel/include/ck_tile/core/numeric/vector_type_hip.hpp:13:
  /home/meng/workspace/flash-attention/csrc/composable_kernel/include/ck_tile/core/numeric/bfloat16.hpp:170:21: error: invalid operand for instruction
    170 |     asm volatile("\n \
        |                     ^
  <inline asm>:2:26: note: instantiated into assembly here
      2 |              v_cmp_u_f32 s[4:5], v0, v0
        |                          ^

The wheel download failed: the file does not exist at the guessed URL shown in the log below.


  Total number of replaced kernel launches: 1
  /home/meng/workspace/python/lib/python3.10/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
    warnings.warn(
  running bdist_wheel
  ---------------- https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+rocm62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl     -------------- flash_attn-2.6.3+rocm62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
  Guessing wheel URL:  https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+rocm62torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
  Precompiled wheel not found. Building from source...

### Operating System

Ubuntu 22.04.5 LTS (Jammy Jellyfish)

### CPU

 AMD Ryzen 9 9900X 12-Core Processor

### GPU

AMD Radeon RX 7900 XTX

### ROCm Version

ROCm 6.2.3

### ROCm Component

_No response_

### Steps to Reproduce

1. git clone --recursive https://github.com/ROCm/flash-attention.git
2. cd flash-attention
3. MAX_JOBS=$((`nproc` - 1)) pip install -v 

### (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

=====================
HSA System Attributes
=====================
Runtime Version:         1.1
Runtime Ext Version:     1.6
System Timestamp Freq.:  1000.000000MHz
Sig. Max Wait Duration:  18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model:           LARGE
System Endianness:       LITTLE
Mwaitx:                  DISABLED
DMAbuf Support:          NO

==========
HSA Agents
==========
*******
Agent 1
*******
  Name:                    CPU
  Uuid:                    CPU-XX
  Marketing Name:          CPU
  Vendor Name:             CPU
  Feature:                 None specified
  Profile:                 FULL_PROFILE
  Float Round Mode:        NEAR
  Max Queue Number:        0(0x0)
  Queue Min Size:          0(0x0)
  Queue Max Size:          0(0x0)
  Queue Type:              MULTI
  Node:                    0
  Device Type:             CPU
  Cache Info:
  Chip ID:                 0(0x0)
  Cacheline Size:          64(0x40)
  Internal Node ID:        0
  Compute Unit:            24
  SIMDs per CU:            0
  Shader Engines:          0
  Shader Arrs. per Eng.:   0
  Memory Properties:
  Features:                None
  Pool Info:
    Pool 1
      Segment:                 GLOBAL; FLAGS: KERNARG, FINE GRAINED
      Size:                    48094712(0x2ddddf8) KB
      Allocatable:             TRUE
      Alloc Granule:           4KB
      Alloc Recommended Granule:4KB
      Alloc Alignment:         4KB
      Accessible by all:       TRUE
    Pool 2
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED
      Size:                    48094712(0x2ddddf8) KB
      Allocatable:             TRUE
      Alloc Granule:           4KB
      Alloc Recommended Granule:4KB
      Alloc Alignment:         4KB
      Accessible by all:       TRUE
  ISA Info:
*******
Agent 2
*******
  Name:                    gfx1100
  Marketing Name:          AMD Radeon RX 7900 XTX
  Vendor Name:             AMD
  Feature:                 KERNEL_DISPATCH
  Profile:                 BASE_PROFILE
  Float Round Mode:        NEAR
  Max Queue Number:        16(0x10)
  Queue Min Size:          4096(0x1000)
  Queue Max Size:          131072(0x20000)
  Queue Type:              MULTI
  Node:                    1
  Device Type:             GPU
  Cache Info:
    L1:                      32(0x20) KB
    L2:                      6144(0x1800) KB
    L3:                      98304(0x18000) KB
  Chip ID:                 29772(0x744c)
  Cacheline Size:          64(0x40)
  Max Clock Freq. (MHz):   2482
  Internal Node ID:        1
  Compute Unit:            96
  SIMDs per CU:            2
  Shader Engines:          6
  Shader Arrs. per Eng.:   2
  Coherent Host Access:    FALSE
  Memory Properties:
  Features:                KERNEL_DISPATCH
  Fast F16 Operation:      TRUE
  Wavefront Size:          32(0x20)
  Workgroup Max Size:      1024(0x400)
  Workgroup Max Size per Dimension:
    x                        1024(0x400)
    y                        1024(0x400)
    z                        1024(0x400)
  Max Waves Per CU:        32(0x20)
  Max Work-item Per CU:    1024(0x400)
  Grid Max Size:           4294967295(0xffffffff)
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)
    y                        4294967295(0xffffffff)
    z                        4294967295(0xffffffff)
  Max fbarriers/Workgrp:   32
  Packet Processor uCode:: 2280
  SDMA engine uCode::      21
  IOMMU Support::          None
  Pool Info:
    Pool 1
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED
      Size:                    25079976(0x17eb0a8) KB
      Allocatable:             TRUE
      Alloc Granule:           4KB
      Alloc Recommended Granule:2048KB
      Alloc Alignment:         4KB
      Accessible by all:       FALSE
    Pool 2
      Segment:                 GROUP
      Size:                    64(0x40) KB
      Allocatable:             FALSE
      Alloc Granule:           0KB
      Alloc Recommended Granule:0KB
      Alloc Alignment:         0KB
      Accessible by all:       FALSE
  ISA Info:
    ISA 1
      Name:                    amdgcn-amd-amdhsa--gfx1100
      Machine Models:          HSA_MACHINE_MODEL_LARGE
      Profiles:                HSA_PROFILE_BASE
      Default Rounding Mode:   NEAR
      Default Rounding Mode:   NEAR
      Fast f16:                TRUE
      Workgroup Max Size:      1024(0x400)
      Workgroup Max Size per Dimension:
        x                        1024(0x400)
        y                        1024(0x400)
        z                        1024(0x400)
      Grid Max Size:           4294967295(0xffffffff)
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)
        y                        4294967295(0xffffffff)
        z                        4294967295(0xffffffff)
      FBarrier Max Size:       32
*** Done ***

### Additional Information

_No response_
evshiron commented 3 weeks ago

The main branch of this repo only works on CDNA 2/3 GPUs.
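
A quick way to check what a build would target is to read the GPU architecture string from a ROCm build of PyTorch. A minimal sketch, assuming `torch.cuda.get_device_properties(...).gcnArchName` is available (it is on recent ROCm wheels) and using a non-exhaustive CDNA 2/3 list:

```python
import torch

# Non-exhaustive CDNA 2/3 architecture prefixes (MI200 = gfx90a, MI300 = gfx94x);
# the list is an assumption based on this discussion, not an official support matrix.
CDNA_PREFIXES = ("gfx90a", "gfx940", "gfx941", "gfx942")

# On ROCm wheels of PyTorch, gcnArchName reports the GPU target, e.g. "gfx1100"
# for an RX 7900 XTX (possibly with feature suffixes such as ":xnack-").
arch = torch.cuda.get_device_properties(0).gcnArchName
if arch.startswith(CDNA_PREFIXES):
    print(f"{arch}: CDNA target, the main-branch CK kernels should build")
else:
    print(f"{arch}: not a CDNA 2/3 target, expect 'invalid operand for instruction' errors")
```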

If you are looking for a memory-efficient attention implementation, you can already use SDPA in the latest PyTorch (experimental), which is powered by https://github.com/ROCm/AOTriton.
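
For example, a minimal SDPA call looks like the sketch below. Whether the flash/memory-efficient backend is actually used on gfx1100 depends on your PyTorch and AOTriton build; otherwise the math fallback is selected:

```python
import torch
import torch.nn.functional as F

# Minimal SDPA sketch; shapes are (batch, heads, seq_len, head_dim).
# On ROCm the flash/memory-efficient backends come from AOTriton and may still
# be experimental on gfx1100, in which case the math fallback is used instead.
device = "cuda"  # ROCm GPUs are exposed through the "cuda" device type in PyTorch
q = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```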

If you want an older Flash Attention implementation that works on RDNA 3 GPUs, try the howiejay/navi_support branch; you can find related discussions in some of the issues.

jamesxu2 commented 2 weeks ago

Hi @Kingmeng-Stack, support for FlashAttention was initially added for MI200 and MI300 (i.e. CDNA 2/3 accelerators), but some of the underlying composable kernel backend includes CDNA-specific assembly, resulting in those `error: invalid operand for instruction` errors.

There's an effort to ship an RDNA-compatible Triton kernel backend as an alternative to CK so that we can support FA on your device, but it's a work in progress and is currently being upstreamed into FlashAttention.

Thanks to @evshiron for chiming in. I did not have any luck using that navi_support branch, but I believe the AOTriton implementation for FA does work.

gowthamtupili commented 1 week ago

Hi @jamesxu2, I am trying to build FA2 for the RDNA 3 GPU architecture. Can you point me to which files in the composable kernel have CDNA-specific instructions, so that I can try to build FA2?

jamesxu2 commented 1 week ago

@gowthamtupili, you can see from the error in the original issue that the file /[...]/composable_kernel/include/ck_tile/core/numeric/bfloat16.hpp is flagged with `error: invalid operand for instruction`. I am not sure how you might find all of the files whose inline assembly doesn't comply with the RDNA ISA. Our current implementation of FA relies on ck_tile, a subcomponent of composable kernel, which is simply neither designed for nor tested on RDNA; making it work for RDNA would be a significant undertaking, if not a rewrite.
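
If you still want to survey the code, one crude option is to search the bundled ck_tile headers for inline assembly and architecture-specific builtins. A rough Python sketch (the path and heuristics are assumptions; a recursive grep would do the same):

```python
from pathlib import Path

# Walk the ck_tile headers bundled with this repo (path is relative to the repo
# root) and list files containing inline assembly or AMD GCN builtins; each hit
# is only a candidate for CDNA-specific code, not a confirmed incompatibility.
CK_TILE_ROOT = Path("csrc/composable_kernel/include/ck_tile")

for header in sorted(CK_TILE_ROOT.rglob("*.hpp")):
    text = header.read_text(errors="ignore")
    if "asm volatile" in text or "__builtin_amdgcn" in text:
        print(header)
```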

Further to that, I'm not sure how you plan to build FA without that inline assembly, unless you're able to translate it yourself into some RDNA-supporting equivalent.

evshiron commented 1 week ago

@gowthamtupili

You can find relevant fused kernels here:

The examples with wmma are the existing implementations that support Navi 3x GPUs in ROCm/flash-attention@howiejay/navi_support. The xdl variants are written for CDNA and are far more complete. Please note that CK is a template library, and it can be a pain for unseasoned developers.

There is another Flash Attention implementation written in rocWMMA which works on Navi 3x GPUs too:

gowthamtupili commented 5 days ago

Hi @jamesxu2, @evshiron, thank you for the inputs. I now have a much clearer path forward for building FA2 on RDNA 3. Thanks again for the support.

githust66 commented 2 days ago

> @gowthamtupili
>
> You can find relevant fused kernels here:
>
> The examples with wmma are the existing implementations that support Navi 3x GPUs in ROCm/flash-attention@howiejay/navi_support. The xdl variants are written for CDNA and are far more complete. Please note that CK is a template library, and it can be a pain for unseasoned developers.
>
> There is another Flash Attention implementation written in rocWMMA which works on Navi 3x GPUs too:

Hi @evshiron, can I install both of these together?