Dao-AILab / flash-attention

Fast and memory-efficient exact attention

v2.5.5 fails to build on NVIDIA Jetson AGX Orin (aarch64/arch=compute_87,code=sm_87) #860

Open · ms1design opened this issue 6 months ago

ms1design commented 6 months ago

Hi,

I noticed that the new version 2.5.5 breaks the build from source on Jetson devices.

Previously working build script:

git clone --depth=1 https://github.com/Dao-AILab/flash-attention
cd flash-attention
git apply flash-attn.diff
python3 setup.py install

flash-attn.diff - the Git patch used to enable the build from source:

diff --git a/setup.py b/setup.py
index 75c92cb..5971477 100644
--- a/setup.py
+++ b/setup.py
@@ -45,20 +45,38 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
 # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
 FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"

+def get_system() -> str:
+    """
+    Returns the system name as used in wheel filenames.
+    """
+    if platform.system() == "Windows":
+        return "win"
+    elif platform.system() == "Darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:1])
+        return f"macos_{mac_version}"
+    elif platform.system() == "Linux":
+        return "linux"
+    else:
+        raise ValueError("Unsupported system: {}".format(platform.system()))

-def get_platform():
+
+def get_arch():
     """
-    Returns the platform name as used in wheel filenames.
+    Returns the machine architecture as used in wheel filenames.
     """
-    if sys.platform.startswith("linux"):
-        return "linux_x86_64"
-    elif sys.platform == "darwin":
-        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
-        return f"macosx_{mac_version}_x86_64"
-    elif sys.platform == "win32":
-        return "win_amd64"
+    if platform.machine() == "x86_64":
+        return "x86_64"
+    elif platform.machine() == "arm64" or platform.machine() == "aarch64":
+        return "aarch64"
     else:
-        raise ValueError("Unsupported platform: {}".format(sys.platform))
+        raise ValueError("Unsupported arch: {}".format(platform.machine()))
+
+
+def get_platform() -> str:
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    return f"{get_system()}_{get_arch()}"

 def get_cuda_bare_metal_version(cuda_dir):
@@ -115,14 +133,11 @@ if not SKIP_CUDA_BUILD:
                 "FlashAttention is only supported on CUDA 11.6 and above.  "
                 "Note: make sure nvcc has a supported version by running nvcc -V."
             )
-    # cc_flag.append("-gencode")
-    # cc_flag.append("arch=compute_75,code=sm_75")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
+
     if CUDA_HOME is not None:
         if bare_metal_version >= Version("11.8"):
             cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_90,code=sm_90")
+            cc_flag.append("arch=compute_87,code=sm_87")

     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
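
For anyone adapting this patch, here is a minimal sanity check (my own sketch, not part of the patch) to confirm the values the patched get_platform() and cc_flag logic would see on an AGX Orin; it assumes a working CUDA-enabled PyTorch install:

# Sketch: verify platform/arch detection and the expected gencode flag.
python3 - <<'EOF'
import platform
import torch

print(platform.system(), platform.machine())  # expected: Linux aarch64

# AGX Orin reports compute capability (8, 7) -> arch=compute_87,code=sm_87
major, minor = torch.cuda.get_device_capability(0)
print(f"-gencode arch=compute_{major}{minor},code=sm_{major}{minor}")
EOF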

Error when building v2.5.5 on Jetson:

...
12:11:17 FAILED: /opt/flash-attention/build/temp.linux-aarch64-cpython-310/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.o 
12:11:17 /usr/local/cuda/bin/nvcc  -I/opt/flash-attention/csrc/flash_attn -I/opt/flash-attention/csrc/flash_attn/src -I/opt/flash-attention/csrc/cutlass/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c -c /opt/flash-attention/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu -o /opt/flash-attention/build/temp.linux-aarch64-cpython-310/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -gencode arch=compute_87,code=sm_87 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=1 -ccbin gcc
12:11:17 Killed
...
12:16:23 ninja: build stopped: subcommand failed.
12:16:23 Traceback (most recent call last):
12:16:23   File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2100, in _run_ninja_build
12:16:23     subprocess.run(
12:16:23   File "/usr/lib/python3.10/subprocess.py", line 526, in run
12:16:23     raise CalledProcessError(retcode, process.args,
12:16:23 subprocess.CalledProcessError: Command '['ninja', '-v', '-j', '12']' returned non-zero exit status 1.
12:16:23 
12:16:23 The above exception was the direct cause of the following exception:
12:16:23 
12:16:23 Traceback (most recent call last):
12:16:23   File "/opt/flash-attention/setup.py", line 302, in <module>
12:16:23     setup(
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/__init__.py", line 103, in setup
12:16:23     return distutils.core.setup(**attrs)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/core.py", line 185, in setup
12:16:23     return run_commands(dist)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/core.py", line 201, in run_commands
12:16:23     dist.run_commands()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 969, in run_commands
12:16:23     self.run_command(cmd)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
12:16:23     super().run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
12:16:23     cmd_obj.run()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/install.py", line 85, in run
12:16:23     self.do_egg_install()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/install.py", line 137, in do_egg_install
12:16:23     self.run_command('bdist_egg')
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py", line 318, in run_command
12:16:23     self.distribution.run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
12:16:23     super().run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
12:16:23     cmd_obj.run()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/bdist_egg.py", line 167, in run
12:16:23     cmd = self.call_command('install_lib', warn_dir=0)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/bdist_egg.py", line 153, in call_command
12:16:23     self.run_command(cmdname)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py", line 318, in run_command
12:16:23     self.distribution.run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
12:16:23     super().run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
12:16:23     cmd_obj.run()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/install_lib.py", line 11, in run
12:16:23     self.build()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/command/install_lib.py", line 111, in build
12:16:23     self.run_command('build_ext')
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py", line 318, in run_command
12:16:23     self.distribution.run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
12:16:23     super().run_command(command)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
12:16:23     cmd_obj.run()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 89, in run
12:16:23     _build_ext.run(self)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/old_build_ext.py", line 186, in run
12:16:23     _build_ext.build_ext.run(self)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
12:16:23     self.build_extensions()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 873, in build_extensions
12:16:23     build_ext.build_extensions(self)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions
12:16:23     _build_ext.build_ext.build_extensions(self)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 467, in build_extensions
12:16:23     self._build_extensions_serial()
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 493, in _build_extensions_serial
12:16:23     self.build_extension(ext)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 250, in build_extension
12:16:23     _build_ext.build_extension(self, ext)
12:16:23   File "/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 548, in build_extension
12:16:23     objects = self.compiler.compile(
12:16:23   File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 686, in unix_wrap_ninja_compile
12:16:23     _write_ninja_file_and_compile_objects(
12:16:23   File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1774, in _write_ninja_file_and_compile_objects
12:16:23     _run_ninja_build(
12:16:23   File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2116, in _run_ninja_build
12:16:23     raise RuntimeError(message) from e
12:16:23 RuntimeError: Error compiling objects for extension
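
Note the Killed line in the nvcc step above: on a memory-constrained device this is typically the kernel's OOM killer terminating the compiler, since ninja runs 12 parallel jobs and each nvcc invocation itself uses --threads 4. A workaround worth trying (my assumption, not a confirmed fix) is to cap the build parallelism via MAX_JOBS, which torch.utils.cpp_extension honors when it generates the ninja command line:

# Cap parallel compile jobs so nvcc doesn't exhaust RAM;
# 2 is a conservative guess -- raise it if memory allows.
MAX_JOBS=2 python3 setup.py install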
iamsiddhantsahu commented 2 months ago

@ms1design Thanks a lot for sharing this -- I was also trying to build FlashAttention on the Jetson AGX Orin 64 GB, so I followed your steps and applied the diff you specified. I am cloning based off the current commit 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c, but I am not able to build it. Can you share the commit you were able to build with, so that I can git checkout that specific commit -- or even better, could you share the .whl file?

git clone --depth=1 https://github.com/Dao-AILab/flash-attention
cd flash-attention
git apply flash-attn.diff
python3 setup.py install
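
To pin a specific commit once one is known, the shallow clone has to go, since --depth=1 only keeps the latest commit. Something like this sketch (the commit hash is a placeholder, since the working revision isn't stated in this thread):

# Full clone so arbitrary commits are reachable, then pin a known-good one.
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
git checkout <known-good-commit>   # placeholder
git apply flash-attn.diff
python3 setup.py install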

By the way, how long did it take you to build it? For me the build process seems to take forever.

iamsiddhantsahu commented 2 months ago

@ms1design How long does your build take to finish? For me it never seems to finish and gets stuck at this point:

(screenshot from 2024-07-16: build output hanging during compilation)

And which nvcc and CUDA version are you using?

ms1design commented 2 months ago

Hi @iamsiddhantsahu,

I'm a contributor to the jetson-containers repository, where you can find a flash-attention Docker container for Jetson devices. And yes, it's available, along with pre-built wheels, if you switch to the jetson-containers repo.

Give it a try -- you can also add your own containers and mix any of the available AI/ML libraries for Jetson devices (in most cases shipped as pre-built wheels or tarballs).

flash-attention is usually just one part of my builds there, but the latest successful build used:

| Lib      | Version |
|----------|---------|
| cuda     | 12.2    |
| cudnn    | 8.9     |
| tensorrt | 8.6     |
| python   | 3.10    |
| pytorch  | 2.2     |

You can set all of the above before building your Docker container in the jetson-containers repo.
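
For reference, a hypothetical invocation (command and variable names per the jetson-containers docs; the version pins mirror the table above -- treat the exact syntax as an assumption):

# Build the flash-attention container with pinned component versions.
git clone https://github.com/dusty-nv/jetson-containers
bash jetson-containers/install.sh
CUDA_VERSION=12.2 PYTORCH_VERSION=2.2 jetson-containers build flash-attention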

iamsiddhantsahu commented 2 months ago

Many thanks @ms1design for letting me know -- yes, I am giving it a try right now.

I actually wanted to compare:

  1. TensorRT-LLM (https://github.com/dusty-nv/jetson-containers/tree/master/packages/llm/tensorrt_llm)
  2. QServe (https://github.com/mit-han-lab/qserve)

Which jetson-container would you recommend in order to install QServe?

ms1design commented 2 months ago

@iamsiddhantsahu just try building QServe from source on Jetson in a new container. You can follow the docs: https://github.com/dusty-nv/jetson-containers/blob/master/docs/packages.md

iamsiddhantsahu commented 2 months ago

@ms1design thanks for the suggestion -- yes, I will try that. I think QServe depends on the FlashAttention and xFormers packages, which is why I was interested in the flash-attention container: https://github.com/dusty-nv/jetson-containers/tree/master/packages/llm/flash-attention