Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Could not build wheels for flash-attn, which is required to install pyproject.toml-based projects #416

Open jackaihfia2334 opened 1 year ago

jackaihfia2334 commented 1 year ago
  ptxas info    : Compiling entry function '_Z25flash_bwd_dot_do_o_kernelILb1E23Flash_bwd_kernel_traitsILi64ELi128ELi128ELi8ELi4ELi4ELi4ELb0ELb0EN7cutlass10bfloat16_tE19Flash_kernel_traitsILi64ELi128ELi128ELi8ES2_EEEv16Flash_bwd_params' for 'sm_90'
  ptxas info    : Function properties for _Z25flash_bwd_dot_do_o_kernelILb1E23Flash_bwd_kernel_traitsILi64ELi128ELi128ELi8ELi4ELi4ELi4ELb0ELb0EN7cutlass10bfloat16_tE19Flash_kernel_traitsILi64ELi128ELi128ELi8ES2_EEEv16Flash_bwd_params
      0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
  ptxas info    : Used 34 registers
  ninja: build stopped: subcommand failed.
  Traceback (most recent call last):
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1902, in _run_ninja_build
      subprocess.run(
    File "/usr/lib/python3.10/subprocess.py", line 524, in run
      raise CalledProcessError(retcode, process.args,
  subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

  The above exception was the direct cause of the following exception:

  Traceback (most recent call last):
    File "<string>", line 2, in <module>
    File "<pip-setuptools-caller>", line 34, in <module>
    File "/tmp/pip-install-0em76put/flash-attn_82b7e874dae44f0f854165b5859a6df5/setup.py", line 202, in <module>
      setup(
    File "/usr/local/lib/python3.10/dist-packages/setuptools/__init__.py", line 107, in setup
      return distutils.core.setup(**attrs)
    File "/usr/lib/python3.10/distutils/core.py", line 148, in setup
      dist.run_commands()
    File "/usr/lib/python3.10/distutils/dist.py", line 966, in run_commands
      self.run_command(cmd)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 1234, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/local/lib/python3.10/dist-packages/wheel/bdist_wheel.py", line 343, in run
      self.run_command("build")
    File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 1234, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/lib/python3.10/distutils/command/build.py", line 135, in run
      self.run_command(cmd_name)
    File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 1234, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 84, in run
      _build_ext.run(self)
    File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/old_build_ext.py", line 186, in run
      _build_ext.build_ext.run(self)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 340, in run
      self.build_extensions()
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 848, in build_extensions
      build_ext.build_extensions(self)
    File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions
      _build_ext.build_ext.build_extensions(self)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 449, in build_extensions
      self._build_extensions_serial()
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 474, in _build_extensions_serial
      self.build_extension(ext)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 246, in build_extension
      _build_ext.build_extension(self, ext)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 529, in build_extension
      objects = self.compiler.compile(sources,
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 661, in unix_wrap_ninja_compile
      _write_ninja_file_and_compile_objects(
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1575, in _write_ninja_file_and_compile_objects
      _run_ninja_build(
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1918, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for flash-attn
Running setup.py clean for flash-attn
Failed to build flash-attn
ERROR: Could not build wheels for flash-attn, which is required to install pyproject.toml-based projects
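The log above ends with the generic "Could not build wheels" message, which does not say what went wrong. The compiler diagnostic that actually stopped ninja usually appears earlier in the full output. As a rough sketch for locating it, the snippet below scans a saved log for lines that typically mark the real failure. The patterns are heuristics rather than an official list, and `find_build_errors` and `build.log` are hypothetical names used only for illustration.

```python
import re

# Heuristic markers of the real failure in a ninja/nvcc/g++ log (assumed, not exhaustive).
ERROR_PATTERNS = [
    re.compile(r"^FAILED: "),   # ninja prints this for the build step that failed
    re.compile(r"\berror:"),    # lowercase "error:" as emitted by gcc/clang/nvcc diagnostics
    re.compile(r"\bKilled\b"),  # often means the compiler was stopped for lack of memory
]


def find_build_errors(log_text, context=1):
    """Return (line_number, text) pairs for suspicious lines plus `context` following lines."""
    lines = log_text.splitlines()
    hits = []
    seen = set()
    for i, line in enumerate(lines):
        if any(p.search(line) for p in ERROR_PATTERNS):
            # Keep the matching line and a little context after it, without duplicates.
            for j in range(i, min(i + context + 1, len(lines))):
                if j not in seen:
                    seen.add(j)
                    hits.append((j + 1, lines[j].strip()))
    return hits


# Example use, assuming the full build output was saved to a file named build.log:
#     with open("build.log") as fh:
#         for number, text in find_build_errors(fh.read()):
#             print(number, text)
```

The match on `error:` is case-sensitive on purpose, so pip's own uppercase `ERROR:` summary lines are skipped and only compiler-level diagnostics are reported.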

liuyijiang1994 commented 1 year ago

same

TS10armourer commented 1 year ago

same

mindkrypted commented 1 year ago

@jackaihfia2334

I was not able to build or install it on Windows (I don't think it's supported yet), so I set up a new environment under WSL2 to use Flash Attention for LLM training.

I started by installing the Text Generation WebUI requirements, followed by CUDA 11.8 and NVCC for CUDA 11.8.
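Since the setup pairs CUDA 11.8 with an NVCC for CUDA 11.8, one thing worth checking before building is that the `nvcc` on the PATH reports the same release as the CUDA version PyTorch was built against. The sketch below only does the string comparison. The helper names are hypothetical, and it assumes `nvcc --version` prints a line of the form `release 11.8, V11.8.89`.

```python
import re


def parse_nvcc_release(nvcc_output):
    """Extract 'major.minor' from `nvcc --version` output, or None if it is not found."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def cuda_versions_match(nvcc_output, torch_cuda):
    """Compare nvcc's release with the CUDA version string reported by PyTorch."""
    nvcc_release = parse_nvcc_release(nvcc_output)
    return nvcc_release is not None and nvcc_release == torch_cuda


# Example use on a machine with nvcc and torch installed:
#     import subprocess, torch
#     out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
#     print(cuda_versions_match(out, torch.version.cuda))
```

A mismatch here is only one possible cause of a failed build, so a passing check does not rule out other problems.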

Then I tried installing flash-attention without success. I was getting the same error message as in your issue, though the shell log looked a bit different.

Lots of troubleshooting later, I finally got it working. Not sure if it'll work for you, but this might help.

conda list -n textgen ``` # Name Version Build Channel _libgcc_mutex 0.1 main _openmp_mutex 5.1 1_gnu abseil-cpp 20211102.0 hd4dd3e8_0 absl-py 1.4.0 pypi_0 pypi accelerate 0.21.0 pypi_0 pypi aiobotocore 2.5.0 py310h06a4308_0 aiofiles 22.1.0 py310h06a4308_0 aiohttp 3.8.3 py310h5eee18b_0 aioitertools 0.7.1 pyhd3eb1b0_0 aiosignal 1.2.0 pyhd3eb1b0_0 aiosqlite 0.18.0 py310h06a4308_0 alabaster 0.7.12 pyhd3eb1b0_0 altair 5.0.1 pypi_0 pypi anaconda 2023.07 py310_1 anyio 3.5.0 py310h06a4308_0 appdirs 1.4.4 pyhd3eb1b0_0 argon2-cffi 21.3.0 pyhd3eb1b0_0 argon2-cffi-bindings 21.2.0 py310h7f8727e_0 arrow 1.2.3 py310h06a4308_1 arrow-cpp 11.0.0 py310h7516544_0 astroid 2.14.2 py310h06a4308_0 astropy 5.1 py310ha9d4c09_0 asttokens 2.0.5 pyhd3eb1b0_0 async-timeout 4.0.2 py310h06a4308_0 atomicwrites 1.4.0 py_0 attrs 22.1.0 py310h06a4308_0 auto-gptq 0.3.0+cu117 pypi_0 pypi automat 20.2.0 py_0 autopep8 1.6.0 pyhd3eb1b0_1 aws-c-common 0.4.57 he6710b0_1 aws-c-event-stream 0.1.6 h2531618_5 aws-checksums 0.1.9 he6710b0_0 aws-sdk-cpp 1.8.185 hce553d0_0 babel 2.11.0 py310h06a4308_0 backcall 0.2.0 pyhd3eb1b0_0 bcrypt 3.2.0 py310h5eee18b_1 beautifulsoup4 4.12.2 py310h06a4308_0 binaryornot 0.4.4 pyhd3eb1b0_1 bitsandbytes 0.41.1 pypi_0 pypi black 23.3.0 py310h06a4308_0 blas 1.0 mkl bleach 4.1.0 pyhd3eb1b0_0 blosc 1.21.3 h6a678d5_0 bokeh 3.2.1 py310h2f386ee_0 boost-cpp 1.73.0 h7f8727e_12 botocore 1.29.76 py310h06a4308_0 bottleneck 1.3.5 py310ha9d4c09_0 brotli 1.0.9 h5eee18b_7 brotli-bin 1.0.9 h5eee18b_7 brotlipy 0.7.0 py310h7f8727e_1002 brunsli 0.1 h2531618_0 bzip2 1.0.8 h7b6447c_0 c-ares 1.19.0 h5eee18b_0 c-blosc2 2.8.0 h6a678d5_0 ca-certificates 2023.05.30 h06a4308_0 cachetools 5.3.1 pypi_0 pypi certifi 2023.7.22 py310h06a4308_0 cffi 1.15.1 py310h5eee18b_3 cfitsio 3.470 h5893167_7 chardet 4.0.0 py310h06a4308_1003 charls 2.2.0 h2531618_0 charset-normalizer 2.0.4 pyhd3eb1b0_0 clang 9.0.0 default_hde54327_4 conda-forge clang-tools 9.0.0 default_hde54327_4 conda-forge clangxx 9.0.0 default_h6bfbf51_4 
conda-forge click 8.0.4 py310h06a4308_0 cloudpickle 2.2.1 py310h06a4308_0 cmake 3.25.0 pypi_0 pypi colorama 0.4.6 py310h06a4308_0 colorcet 3.0.1 py310h06a4308_0 comm 0.1.2 py310h06a4308_0 constantly 15.1.0 py310h06a4308_0 contourpy 1.0.5 py310hdb19cb5_0 cookiecutter 1.7.3 pyhd3eb1b0_0 cryptography 41.0.2 py310h774aba0_0 cssselect 1.1.0 pyhd3eb1b0_0 cuda-nvcc 11.8.89 0 nvidia/label/cuda-11.8.0 curl 8.1.1 h37d81fd_2 cycler 0.11.0 pyhd3eb1b0_0 cytoolz 0.12.0 py310h5eee18b_0 daal4py 2023.1.1 py310h3c18c91_0 dal 2023.1.1 hdb19cb5_48679 dask 2023.6.0 py310h06a4308_0 dask-core 2023.6.0 py310h06a4308_0 datasets 2.12.0 py310h06a4308_0 datashader 0.15.1 py310h06a4308_0 datashape 0.5.4 py310h06a4308_1 dbus 1.13.18 hb2f20db_0 debugpy 1.6.7 py310h6a678d5_0 decorator 5.1.1 pyhd3eb1b0_0 defusedxml 0.7.1 pyhd3eb1b0_0 diff-match-patch 20200713 pyhd3eb1b0_0 dill 0.3.6 py310h06a4308_0 diskcache 5.6.1 pypi_0 pypi distributed 2023.6.0 py310h06a4308_0 docker-pycreds 0.4.0 pypi_0 pypi docstring-to-markdown 0.11 py310h06a4308_0 docutils 0.18.1 py310h06a4308_3 einops 0.6.1 pypi_0 pypi entrypoints 0.4 py310h06a4308_0 et_xmlfile 1.1.0 py310h06a4308_0 exceptiongroup 1.0.4 py310h06a4308_0 executing 0.8.3 pyhd3eb1b0_0 exllama 0.0.10+cu117 pypi_0 pypi expat 2.4.9 h6a678d5_0 fastapi 0.95.2 pypi_0 pypi ffmpy 0.3.1 pypi_0 pypi filelock 3.9.0 py310h06a4308_0 flake8 6.0.0 py310h06a4308_0 flash-attn 2.0.4 pypi_0 pypi flask 2.2.2 py310h06a4308_0 fontconfig 2.14.1 h52c9d5c_1 fonttools 4.25.0 pyhd3eb1b0_0 freetype 2.12.1 h4a9f257_0 frozenlist 1.3.3 py310h5eee18b_0 fsspec 2023.4.0 py310h06a4308_0 gensim 4.3.0 py310h1128e8f_0 gflags 2.2.2 he6710b0_0 giflib 5.2.1 h5eee18b_3 gitdb 4.0.10 pypi_0 pypi gitpython 3.1.32 pypi_0 pypi glib 2.69.1 he621ea3_2 glog 0.5.0 h2531618_0 gmp 6.2.1 h295c915_3 gmpy2 2.1.2 py310heeb90bb_0 google-auth 2.22.0 pypi_0 pypi google-auth-oauthlib 1.0.0 pypi_0 pypi gradio 3.33.1 pypi_0 pypi gradio-client 0.2.5 pypi_0 pypi greenlet 2.0.1 py310h6a678d5_0 grpc-cpp 1.46.1 h33aed49_1 
grpcio 1.56.2 pypi_0 pypi gst-plugins-base 1.14.1 h6a678d5_1 gstreamer 1.14.1 h5eee18b_1 h11 0.14.0 pypi_0 pypi h5py 3.7.0 py310he06866b_0 hdf5 1.10.6 h3ffc7dd_1 heapdict 1.0.1 pyhd3eb1b0_0 holoviews 1.17.0 py310h06a4308_0 httpcore 0.17.3 pypi_0 pypi httpx 0.24.1 pypi_0 pypi huggingface_hub 0.15.1 py310h06a4308_0 hvplot 0.8.4 py310h06a4308_0 hyperlink 21.0.0 pyhd3eb1b0_0 icu 58.2 he6710b0_3 idna 3.4 py310h06a4308_0 imagecodecs 2021.8.26 py310h46e8fbd_2 imageio 2.31.1 py310h06a4308_0 imagesize 1.4.1 py310h06a4308_0 imbalanced-learn 0.10.1 py310h06a4308_1 importlib-metadata 6.0.0 py310h06a4308_0 importlib_metadata 6.0.0 hd3eb1b0_0 incremental 21.3.0 pyhd3eb1b0_0 inflection 0.5.1 py310h06a4308_0 iniconfig 1.1.1 pyhd3eb1b0_0 intake 0.6.8 py310h06a4308_0 intel-openmp 2023.1.0 hdb19cb5_46305 intervaltree 3.1.0 pyhd3eb1b0_0 ipykernel 6.19.2 py310h2f386ee_0 ipython 8.12.0 py310h06a4308_0 ipython_genutils 0.2.0 pyhd3eb1b0_1 ipywidgets 8.0.4 py310h06a4308_0 isort 5.9.3 pyhd3eb1b0_0 itemadapter 0.3.0 pyhd3eb1b0_0 itemloaders 1.0.4 pyhd3eb1b0_1 itsdangerous 2.0.1 pyhd3eb1b0_0 jaraco.classes 3.2.1 pyhd3eb1b0_0 jedi 0.18.1 py310h06a4308_1 jeepney 0.7.1 pyhd3eb1b0_0 jellyfish 0.9.0 py310h7f8727e_0 jinja2 3.1.2 py310h06a4308_0 jinja2-time 0.2.0 pyhd3eb1b0_3 jmespath 0.10.0 pyhd3eb1b0_0 joblib 1.2.0 py310h06a4308_0 jpeg 9e h5eee18b_1 jq 1.6 h27cfd23_1000 json5 0.9.6 pyhd3eb1b0_0 jsonschema 4.17.3 py310h06a4308_0 jupyter 1.0.0 py310h06a4308_8 jupyter_client 7.4.9 py310h06a4308_0 jupyter_console 6.6.3 py310h06a4308_0 jupyter_core 5.3.0 py310h06a4308_0 jupyter_events 0.6.3 py310h06a4308_0 jupyter_server 1.23.4 py310h06a4308_0 jupyter_server_fileid 0.9.0 py310h06a4308_0 jupyter_server_ydoc 0.8.0 py310h06a4308_1 jupyter_ydoc 0.2.4 py310h06a4308_0 jupyterlab 3.6.3 py310h06a4308_0 jupyterlab_pygments 0.1.2 py_0 jupyterlab_server 2.22.0 py310h06a4308_0 jupyterlab_widgets 3.0.5 py310h06a4308_0 jxrlib 1.1 h7b6447c_2 keyring 23.13.1 py310h06a4308_0 kiwisolver 1.4.4 py310h6a678d5_0 krb5 1.20.1 
h568e23c_1 lazy-object-proxy 1.6.0 py310h7f8727e_0 lazy_loader 0.2 py310h06a4308_0 lcms2 2.12 h3be6417_0 ld_impl_linux-64 2.38 h1181459_1 lerc 3.0 h295c915_0 libaec 1.0.4 he6710b0_1 libboost 1.73.0 h28710b8_12 libbrotlicommon 1.0.9 h5eee18b_7 libbrotlidec 1.0.9 h5eee18b_7 libbrotlienc 1.0.9 h5eee18b_7 libclang 10.0.1 default_hb85057a_2 libcurl 8.1.1 h91b91d3_2 libdeflate 1.17 h5eee18b_0 libedit 3.1.20221030 h5eee18b_0 libev 4.33 h7f8727e_1 libevent 2.1.12 h8f2d780_0 libffi 3.4.4 h6a678d5_0 libgcc-ng 11.2.0 h1234567_1 libgfortran-ng 11.2.0 h00389a5_1 libgfortran5 11.2.0 h1234567_1 libgomp 11.2.0 h1234567_1 libllvm10 10.0.1 hbcb73fb_5 libllvm14 14.0.6 hdb19cb5_3 libllvm9 9.0.1 default_hc23dcda_4 conda-forge libnghttp2 1.52.0 ha637b67_1 libpng 1.6.39 h5eee18b_0 libpq 12.15 h37d81fd_1 libprotobuf 3.20.3 he621ea3_0 libsodium 1.0.18 h7b6447c_0 libspatialindex 1.9.3 h2531618_0 libssh2 1.10.0 h37d81fd_2 libstdcxx-ng 11.2.0 h1234567_1 libthrift 0.15.0 h0d84882_2 libtiff 4.5.0 h6a678d5_2 libuuid 1.41.5 h5eee18b_0 libwebp 1.2.4 h11a3e52_1 libwebp-base 1.2.4 h5eee18b_1 libxcb 1.15 h7f8727e_0 libxkbcommon 1.0.1 hfa300c1_0 libxml2 2.9.14 h74e7548_0 libxslt 1.1.35 h4e12654_0 libzopfli 1.0.3 he6710b0_0 linkify-it-py 2.0.0 py310h06a4308_0 lit 15.0.7 pypi_0 pypi llama-cpp-python 0.1.77 pypi_0 pypi llama-cpp-python-cuda 0.1.77+cu117 pypi_0 pypi llvmlite 0.40.0 py310he621ea3_0 locket 1.0.0 py310h06a4308_0 lxml 4.9.1 py310h1edc446_0 lz4 4.3.2 py310h5eee18b_0 lz4-c 1.9.4 h6a678d5_0 lzo 2.10 h7b6447c_2 markdown 3.4.1 py310h06a4308_0 markdown-it-py 2.2.0 py310h06a4308_1 markupsafe 2.1.1 py310h7f8727e_0 matplotlib 3.7.1 py310h06a4308_1 matplotlib-base 3.7.1 py310h1128e8f_1 matplotlib-inline 0.1.6 py310h06a4308_0 mccabe 0.7.0 pyhd3eb1b0_0 mdit-py-plugins 0.3.0 py310h06a4308_0 mdurl 0.1.0 py310h06a4308_0 mistune 0.8.4 py310h7f8727e_1000 mkl 2023.1.0 h6d00ec8_46342 mkl-service 2.4.0 py310h5eee18b_1 mkl_fft 1.3.6 py310h1128e8f_1 mkl_random 1.2.2 py310h1128e8f_1 more-itertools 8.12.0 
pyhd3eb1b0_0 mpc 1.1.0 h10f8cd9_1 mpfr 4.0.2 hb69a4c5_1 mpi 1.0 mpich mpich 4.1.1 hbae89fd_0 mpmath 1.3.0 py310h06a4308_0 msgpack-python 1.0.3 py310hd09550d_0 multidict 6.0.2 py310h5eee18b_0 multipledispatch 0.6.0 py310h06a4308_0 multiprocess 0.70.14 py310h06a4308_0 munkres 1.1.4 py_0 mypy_extensions 0.4.3 py310h06a4308_0 nbclassic 0.5.5 py310h06a4308_0 nbclient 0.5.13 py310h06a4308_0 nbconvert 6.5.4 py310h06a4308_0 nbformat 5.7.0 py310h06a4308_0 ncurses 6.4 h6a678d5_0 nest-asyncio 1.5.6 py310h06a4308_0 networkx 3.1 py310h06a4308_0 ninja 1.11.1 pypi_0 pypi ninja-base 1.10.2 hd09550d_5 nltk 3.8.1 py310h06a4308_0 notebook 6.5.4 py310h06a4308_1 notebook-shim 0.2.2 py310h06a4308_0 nspr 4.35 h6a678d5_0 nss 3.89.1 h6a678d5_0 numba 0.57.0 py310h1128e8f_0 numexpr 2.8.4 py310h85018f9_1 numpy 1.24.3 py310h5f9d8c6_1 numpy-base 1.24.3 py310hb5e798b_1 numpydoc 1.5.0 py310h06a4308_0 oauthlib 3.2.2 pypi_0 pypi oniguruma 6.9.7.1 h27cfd23_0 opencv-python 4.8.0.74 pypi_0 pypi openjpeg 2.4.0 h3ad879b_0 openpyxl 3.0.10 py310h5eee18b_0 openssl 1.1.1u h7f8727e_0 orc 1.7.4 hb3bc3d3_1 orjson 3.9.3 pypi_0 pypi packaging 23.0 py310h06a4308_0 pandas 1.5.3 py310h1128e8f_0 pandocfilters 1.5.0 pyhd3eb1b0_0 panel 1.2.1 py310h06a4308_0 param 1.13.0 py310h06a4308_0 parsel 1.6.0 py310h06a4308_0 parso 0.8.3 pyhd3eb1b0_0 partd 1.2.0 pyhd3eb1b0_1 pathspec 0.10.3 py310h06a4308_0 pathtools 0.1.2 pypi_0 pypi patsy 0.5.3 py310h06a4308_0 pcre 8.45 h295c915_0 peft 0.5.0.dev0 pypi_0 pypi pep8 1.7.1 py310h06a4308_1 pexpect 4.8.0 pyhd3eb1b0_3 pickleshare 0.7.5 pyhd3eb1b0_1003 pillow 10.0.0 pypi_0 pypi pip 23.2.1 py310h06a4308_0 platformdirs 2.5.2 py310h06a4308_0 plotly 5.9.0 py310h06a4308_0 pluggy 1.0.0 py310h06a4308_1 ply 3.11 py310h06a4308_0 pooch 1.4.0 pyhd3eb1b0_0 poyo 0.5.0 pyhd3eb1b0_0 prometheus_client 0.14.1 py310h06a4308_0 prompt-toolkit 3.0.36 py310h06a4308_0 prompt_toolkit 3.0.36 hd3eb1b0_0 protego 0.1.16 py_0 protobuf 4.23.4 pypi_0 pypi psutil 5.9.0 py310h5eee18b_0 ptyprocess 0.7.0 pyhd3eb1b0_2 
pure_eval 0.2.2 pyhd3eb1b0_0 py-cpuinfo 8.0.0 pyhd3eb1b0_1 pyarrow 11.0.0 py310h468efa6_0 pyasn1 0.4.8 pyhd3eb1b0_0 pyasn1-modules 0.2.8 py_0 pycodestyle 2.10.0 py310h06a4308_0 pycparser 2.21 pyhd3eb1b0_0 pyct 0.5.0 py310h06a4308_0 pycurl 7.45.2 py310h37d81fd_0 pydantic 1.10.12 pypi_0 pypi pydispatcher 2.0.5 py310h06a4308_2 pydocstyle 6.3.0 py310h06a4308_0 pydub 0.25.1 pypi_0 pypi pyerfa 2.0.0 py310h7f8727e_0 pyflakes 3.0.1 py310h06a4308_0 pygments 2.15.1 py310h06a4308_1 pylint 2.16.2 py310h06a4308_0 pylint-venv 2.3.0 py310h06a4308_0 pyls-spyder 0.4.0 pyhd3eb1b0_0 pyodbc 4.0.34 py310h6a678d5_0 pyopenssl 23.2.0 py310h06a4308_0 pyparsing 3.0.9 py310h06a4308_0 pyqt 5.15.7 py310h6a678d5_1 pyqt5-sip 12.11.0 pypi_0 pypi pyqtwebengine 5.15.7 py310h6a678d5_1 pyre-extensions 0.0.29 pypi_0 pypi pyrsistent 0.18.0 py310h7f8727e_0 pysocks 1.7.1 py310h06a4308_0 pytables 3.8.0 py310h43249b6_2 pytest 7.4.0 py310h06a4308_0 python 3.10.12 h7a1cb2a_0 python-dateutil 2.8.2 pyhd3eb1b0_0 python-fastjsonschema 2.16.2 py310h06a4308_0 python-json-logger 2.0.7 py310h06a4308_0 python-lmdb 1.4.1 py310h6a678d5_0 python-lsp-black 1.2.1 py310h06a4308_0 python-lsp-jsonrpc 1.0.0 pyhd3eb1b0_0 python-lsp-server 1.7.2 py310h06a4308_0 python-multipart 0.0.6 pypi_0 pypi python-slugify 5.0.2 pyhd3eb1b0_0 python-snappy 0.6.1 py310h6a678d5_0 python-xxhash 2.0.2 py310h5eee18b_1 pytoolconfig 1.2.5 py310h06a4308_1 pytz 2022.7 py310h06a4308_0 pyviz_comms 2.3.0 py310h06a4308_0 pywavelets 1.4.1 py310h5eee18b_0 pyxdg 0.27 pyhd3eb1b0_0 pyyaml 6.0 py310h5eee18b_1 pyzmq 23.2.0 py310h6a678d5_0 qdarkstyle 3.0.2 pyhd3eb1b0_0 qstylizer 0.2.2 py310h06a4308_0 qt-main 5.15.2 h327a75a_7 qt-webengine 5.15.9 hd2b0992_4 qtawesome 1.2.2 py310h06a4308_0 qtconsole 5.4.2 py310h06a4308_0 qtpy 2.2.0 py310h06a4308_0 qtwebkit 5.212 h4eab89a_4 queuelib 1.5.0 py310h06a4308_0 re2 2022.04.01 h295c915_0 readline 8.2 h5eee18b_0 regex 2022.7.9 py310h5eee18b_0 requests 2.31.0 py310h06a4308_0 requests-file 1.5.1 pyhd3eb1b0_0 requests-oauthlib 
1.3.1 pypi_0 pypi responses 0.13.3 pyhd3eb1b0_0 rfc3339-validator 0.1.4 py310h06a4308_0 rfc3986-validator 0.1.1 py310h06a4308_0 rope 1.7.0 py310h06a4308_0 rouge 1.0.1 pypi_0 pypi rsa 4.9 pypi_0 pypi rtree 1.0.1 py310h06a4308_0 s3fs 2023.4.0 py310h06a4308_0 sacremoses 0.0.43 pyhd3eb1b0_0 safetensors 0.3.1 pypi_0 pypi scikit-image 0.20.0 py310h6a678d5_0 scikit-learn 1.3.0 py310h1128e8f_0 scikit-learn-intelex 2023.1.1 py310h06a4308_0 scipy 1.10.1 py310h5f9d8c6_1 scrapy 2.8.0 py310h06a4308_0 seaborn 0.12.2 py310h06a4308_0 secretstorage 3.3.1 py310h06a4308_1 semantic-version 2.10.0 pypi_0 pypi send2trash 1.8.0 pyhd3eb1b0_1 sentencepiece 0.1.99 pypi_0 pypi sentry-sdk 1.29.2 pypi_0 pypi service_identity 18.1.0 pyhd3eb1b0_1 setproctitle 1.3.2 pypi_0 pypi setuptools 68.0.0 py310h06a4308_0 sip 6.6.2 py310h6a678d5_0 six 1.16.0 pyhd3eb1b0_1 smart_open 5.2.1 py310h06a4308_0 smmap 5.0.0 pypi_0 pypi snappy 1.1.9 h295c915_0 sniffio 1.2.0 py310h06a4308_1 snowballstemmer 2.2.0 pyhd3eb1b0_0 sortedcontainers 2.4.0 pyhd3eb1b0_0 soupsieve 2.4 py310h06a4308_0 sphinx 5.0.2 py310h06a4308_0 sphinxcontrib-applehelp 1.0.2 pyhd3eb1b0_0 sphinxcontrib-devhelp 1.0.2 pyhd3eb1b0_0 sphinxcontrib-htmlhelp 2.0.0 pyhd3eb1b0_0 sphinxcontrib-jsmath 1.0.1 pyhd3eb1b0_0 sphinxcontrib-qthelp 1.0.3 pyhd3eb1b0_0 sphinxcontrib-serializinghtml 1.1.5 pyhd3eb1b0_0 spyder 5.4.3 py310h06a4308_1 spyder-kernels 2.4.3 py310h06a4308_0 sqlalchemy 1.4.39 py310h5eee18b_0 sqlite 3.41.2 h5eee18b_0 stack_data 0.2.0 pyhd3eb1b0_0 starlette 0.27.0 pypi_0 pypi statsmodels 0.14.0 py310ha9d4c09_0 sympy 1.11.1 py310h06a4308_0 tabulate 0.8.10 py310h06a4308_0 tbb 2021.8.0 hdb19cb5_0 tbb4py 2021.8.0 py310hdb19cb5_0 tblib 1.7.0 pyhd3eb1b0_0 tenacity 8.2.2 py310h06a4308_0 tensorboard 2.13.0 pypi_0 pypi tensorboard-data-server 0.7.1 pypi_0 pypi terminado 0.17.1 py310h06a4308_0 text-unidecode 1.3 pyhd3eb1b0_0 textdistance 4.2.1 pyhd3eb1b0_0 threadpoolctl 2.2.0 pyh0d69192_0 three-merge 0.1.1 pyhd3eb1b0_0 tifffile 2021.7.2 pyhd3eb1b0_2 
tinycss2 1.2.1 py310h06a4308_0 tk 8.6.12 h1ccaba5_0 tldextract 3.2.0 pyhd3eb1b0_0 tokenizers 0.13.2 py310he7d60b5_1 toml 0.10.2 pyhd3eb1b0_0 tomli 2.0.1 py310h06a4308_0 tomlkit 0.11.1 py310h06a4308_0 toolz 0.12.0 py310h06a4308_0 torch 2.0.1+cu118 pypi_0 pypi torchaudio 2.0.2+cu118 pypi_0 pypi torchvision 0.15.2+cu118 pypi_0 pypi tornado 6.3.2 py310h5eee18b_0 tqdm 4.65.0 py310h2f386ee_0 traitlets 5.7.1 py310h06a4308_0 transformers 4.32.0.dev0 pypi_0 pypi triton 2.0.0 pypi_0 pypi twisted 22.10.0 py310h5eee18b_0 typing-extensions 4.7.1 py310h06a4308_0 typing-inspect 0.9.0 pypi_0 pypi typing_extensions 4.7.1 py310h06a4308_0 tzdata 2023c h04d1e81_0 uc-micro-py 1.0.1 py310h06a4308_0 ujson 5.4.0 py310h6a678d5_0 unidecode 1.2.0 pyhd3eb1b0_0 unixodbc 2.3.11 h5eee18b_0 urllib3 1.26.16 py310h06a4308_0 utf8proc 2.6.1 h27cfd23_0 uvicorn 0.23.2 pypi_0 pypi w3lib 1.21.0 pyhd3eb1b0_0 wandb 0.15.8 pypi_0 pypi watchdog 2.1.6 py310h06a4308_0 wcwidth 0.2.5 pyhd3eb1b0_0 webencodings 0.5.1 py310h06a4308_1 websocket-client 0.58.0 py310h06a4308_4 websockets 11.0.3 pypi_0 pypi werkzeug 2.2.3 py310h06a4308_0 whatthepatch 1.0.2 py310h06a4308_0 wheel 0.38.4 py310h06a4308_0 widgetsnbextension 4.0.5 py310h06a4308_0 wrapt 1.14.1 py310h5eee18b_0 wurlitzer 3.0.2 py310h06a4308_0 xarray 2023.6.0 py310h06a4308_0 xformers 0.0.21+2d3a221.d20230806 pypi_0 pypi xxhash 0.8.0 h7f8727e_3 xyzservices 2022.9.0 py310h06a4308_1 xz 5.4.2 h5eee18b_0 y-py 0.5.9 py310h52d8a92_0 yaml 0.2.5 h7b6447c_0 yapf 0.31.0 pyhd3eb1b0_0 yarl 1.8.1 py310h5eee18b_0 ypy-websocket 0.8.2 py310h06a4308_0 zeromq 4.3.4 h2531618_0 zfp 0.5.5 h295c915_6 zict 2.2.0 py310h06a4308_0 zipp 3.11.0 py310h06a4308_0 zlib 1.2.13 h5eee18b_0 zlib-ng 2.0.7 h5eee18b_0 zope 1.0 py310h06a4308_1 zope.interface 5.4.0 py310h7f8727e_0 zstd 1.5.5 hc292b87_0 ```
pip list ``` Package Version ----------------------------- ------------------------ absl-py 1.4.0 accelerate 0.21.0 aiobotocore 2.5.0 aiofiles 22.1.0 aiohttp 3.8.3 aioitertools 0.7.1 aiosignal 1.2.0 aiosqlite 0.18.0 alabaster 0.7.12 altair 5.0.1 anyio 3.5.0 appdirs 1.4.4 argon2-cffi 21.3.0 argon2-cffi-bindings 21.2.0 arrow 1.2.3 astroid 2.14.2 astropy 5.1 asttokens 2.0.5 async-timeout 4.0.2 atomicwrites 1.4.0 attrs 22.1.0 auto-gptq 0.3.0+cu117 Automat 20.2.0 autopep8 1.6.0 Babel 2.11.0 backcall 0.2.0 bcrypt 3.2.0 beautifulsoup4 4.12.2 binaryornot 0.4.4 bitsandbytes 0.41.1 black 0.0 bleach 4.1.0 bokeh 3.2.1 botocore 1.29.76 Bottleneck 1.3.5 brotlipy 0.7.0 cachetools 5.3.1 certifi 2023.7.22 cffi 1.15.1 chardet 4.0.0 charset-normalizer 2.0.4 click 8.0.4 cloudpickle 2.2.1 cmake 3.25.0 colorama 0.4.6 colorcet 3.0.1 comm 0.1.2 constantly 15.1.0 contourpy 1.0.5 cookiecutter 1.7.3 cryptography 41.0.2 cssselect 1.1.0 cycler 0.11.0 cytoolz 0.12.0 daal4py 2023.1.1 dask 2023.6.0 datasets 2.12.0 datashader 0.15.1 datashape 0.5.4 debugpy 1.6.7 decorator 5.1.1 defusedxml 0.7.1 diff-match-patch 20200713 dill 0.3.6 diskcache 5.6.1 distributed 2023.6.0 docker-pycreds 0.4.0 docstring-to-markdown 0.11 docutils 0.18.1 einops 0.6.1 entrypoints 0.4 et-xmlfile 1.1.0 exceptiongroup 1.0.4 executing 0.8.3 exllama 0.0.10+cu117 fastapi 0.95.2 fastjsonschema 2.16.2 ffmpy 0.3.1 filelock 3.9.0 flake8 6.0.0 flash-attn 2.0.4 Flask 2.2.2 fonttools 4.25.0 frozenlist 1.3.3 fsspec 2023.4.0 gensim 4.3.0 gitdb 4.0.10 GitPython 3.1.32 gmpy2 2.1.2 google-auth 2.22.0 google-auth-oauthlib 1.0.0 gradio 3.33.1 gradio_client 0.2.5 greenlet 2.0.1 grpcio 1.56.2 h11 0.14.0 h5py 3.7.0 HeapDict 1.0.1 holoviews 1.17.0 httpcore 0.17.3 httpx 0.24.1 huggingface-hub 0.15.1 hvplot 0.8.4 hyperlink 21.0.0 idna 3.4 imagecodecs 2021.8.26 imageio 2.31.1 imagesize 1.4.1 imbalanced-learn 0.10.1 importlib-metadata 6.0.0 incremental 21.3.0 inflection 0.5.1 iniconfig 1.1.1 intake 0.6.8 intervaltree 3.1.0 ipykernel 6.19.2 ipython 
8.12.0 ipython-genutils 0.2.0 ipywidgets 8.0.4 isort 5.9.3 itemadapter 0.3.0 itemloaders 1.0.4 itsdangerous 2.0.1 jaraco.classes 3.2.1 jedi 0.18.1 jeepney 0.7.1 jellyfish 0.9.0 Jinja2 3.1.2 jinja2-time 0.2.0 jmespath 0.10.0 joblib 1.2.0 json5 0.9.6 jsonschema 4.17.3 jupyter 1.0.0 jupyter_client 7.4.9 jupyter-console 6.6.3 jupyter_core 5.3.0 jupyter-events 0.6.3 jupyter-server 1.23.4 jupyter_server_fileid 0.9.0 jupyter_server_ydoc 0.8.0 jupyter-ydoc 0.2.4 jupyterlab 3.6.3 jupyterlab-pygments 0.1.2 jupyterlab_server 2.22.0 jupyterlab-widgets 3.0.5 keyring 23.13.1 kiwisolver 1.4.4 lazy_loader 0.2 lazy-object-proxy 1.6.0 linkify-it-py 2.0.0 lit 15.0.7 llama-cpp-python 0.1.77 llama-cpp-python-cuda 0.1.77+cu117 llvmlite 0.40.0 lmdb 1.4.1 locket 1.0.0 lxml 4.9.1 lz4 4.3.2 Markdown 3.4.1 markdown-it-py 2.2.0 MarkupSafe 2.1.1 matplotlib 3.7.1 matplotlib-inline 0.1.6 mccabe 0.7.0 mdit-py-plugins 0.3.0 mdurl 0.1.0 mistune 0.8.4 mkl-fft 1.3.6 mkl-random 1.2.2 mkl-service 2.4.0 more-itertools 8.12.0 mpmath 1.3.0 msgpack 1.0.3 multidict 6.0.2 multipledispatch 0.6.0 multiprocess 0.70.14 munkres 1.1.4 mypy-extensions 0.4.3 nbclassic 0.5.5 nbclient 0.5.13 nbconvert 6.5.4 nbformat 5.7.0 nest-asyncio 1.5.6 networkx 3.1 ninja 1.11.1 nltk 3.8.1 notebook 6.5.4 notebook_shim 0.2.2 numba 0.57.0 numexpr 2.8.4 numpy 1.24.3 numpydoc 1.5.0 oauthlib 3.2.2 opencv-python 4.8.0.74 openpyxl 3.0.10 orjson 3.9.3 packaging 23.0 pandas 1.5.3 pandocfilters 1.5.0 panel 1.2.1 param 1.13.0 parsel 1.6.0 parso 0.8.3 partd 1.2.0 pathspec 0.10.3 pathtools 0.1.2 patsy 0.5.3 peft 0.5.0.dev0 pep8 1.7.1 pexpect 4.8.0 pickleshare 0.7.5 Pillow 10.0.0 pip 23.2.1 platformdirs 2.5.2 plotly 5.9.0 pluggy 1.0.0 ply 3.11 pooch 1.4.0 poyo 0.5.0 prometheus-client 0.14.1 prompt-toolkit 3.0.36 Protego 0.1.16 protobuf 4.23.4 psutil 5.9.0 ptyprocess 0.7.0 pure-eval 0.2.2 py-cpuinfo 8.0.0 pyarrow 11.0.0 pyasn1 0.4.8 pyasn1-modules 0.2.8 pycodestyle 2.10.0 pycparser 2.21 pyct 0.5.0 pycurl 7.45.2 pydantic 1.10.12 PyDispatcher 
2.0.5 pydocstyle 6.3.0 pydub 0.25.1 pyerfa 2.0.0 pyflakes 3.0.1 Pygments 2.15.1 pylint 2.16.2 pylint-venv 2.3.0 pyls-spyder 0.4.0 pyodbc 4.0.34 pyOpenSSL 23.2.0 pyparsing 3.0.9 PyQt5-sip 12.11.0 pyre-extensions 0.0.29 pyrsistent 0.18.0 PySocks 1.7.1 pytest 7.4.0 python-dateutil 2.8.2 python-json-logger 2.0.7 python-lsp-black 1.2.1 python-lsp-jsonrpc 1.0.0 python-lsp-server 1.7.2 python-multipart 0.0.6 python-slugify 5.0.2 python-snappy 0.6.1 pytoolconfig 1.2.5 pytz 2022.7 pyviz-comms 2.3.0 PyWavelets 1.4.1 pyxdg 0.27 PyYAML 6.0 pyzmq 23.2.0 QDarkStyle 3.0.2 qstylizer 0.2.2 QtAwesome 1.2.2 qtconsole 5.4.2 QtPy 2.2.0 queuelib 1.5.0 regex 2022.7.9 requests 2.31.0 requests-file 1.5.1 requests-oauthlib 1.3.1 responses 0.13.3 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rope 1.7.0 rouge 1.0.1 rsa 4.9 Rtree 1.0.1 s3fs 2023.4.0 sacremoses 0.0.43 safetensors 0.3.1 scikit-image 0.20.0 scikit-learn 1.3.0 scikit-learn-intelex 20230426.111436 scipy 1.10.1 Scrapy 2.8.0 seaborn 0.12.2 SecretStorage 3.3.1 semantic-version 2.10.0 Send2Trash 1.8.0 sentencepiece 0.1.99 sentry-sdk 1.29.2 service-identity 18.1.0 setproctitle 1.3.2 setuptools 68.0.0 sip 6.6.2 six 1.16.0 smart-open 5.2.1 smmap 5.0.0 sniffio 1.2.0 snowballstemmer 2.2.0 sortedcontainers 2.4.0 soupsieve 2.4 Sphinx 5.0.2 sphinxcontrib-applehelp 1.0.2 sphinxcontrib-devhelp 1.0.2 sphinxcontrib-htmlhelp 2.0.0 sphinxcontrib-jsmath 1.0.1 sphinxcontrib-qthelp 1.0.3 sphinxcontrib-serializinghtml 1.1.5 spyder 5.4.3 spyder-kernels 2.4.3 SQLAlchemy 1.4.39 stack-data 0.2.0 starlette 0.27.0 statsmodels 0.14.0 sympy 1.11.1 tables 3.8.0 tabulate 0.8.10 TBB 0.2 tblib 1.7.0 tenacity 8.2.2 tensorboard 2.13.0 tensorboard-data-server 0.7.1 terminado 0.17.1 text-unidecode 1.3 textdistance 4.2.1 threadpoolctl 2.2.0 three-merge 0.1.1 tifffile 2021.7.2 tinycss2 1.2.1 tldextract 3.2.0 tokenizers 0.13.2 toml 0.10.2 tomli 2.0.1 tomlkit 0.11.1 toolz 0.12.0 torch 2.0.1+cu118 torchaudio 2.0.2+cu118 torchvision 0.15.2+cu118 tornado 6.3.2 tqdm 4.65.0 
traitlets 5.7.1 transformers 4.32.0.dev0 triton 2.0.0 Twisted 22.10.0 typing_extensions 4.7.1 typing-inspect 0.9.0 uc-micro-py 1.0.1 ujson 5.4.0 Unidecode 1.2.0 urllib3 1.26.16 uvicorn 0.23.2 w3lib 1.21.0 wandb 0.15.8 watchdog 2.1.6 wcwidth 0.2.5 webencodings 0.5.1 websocket-client 0.58.0 websockets 11.0.3 Werkzeug 2.2.3 whatthepatch 1.0.2 wheel 0.38.4 widgetsnbextension 4.0.5 wrapt 1.14.1 wurlitzer 3.0.2 xarray 2023.6.0 xformers 0.0.21+2d3a221.d20230806 xxhash 2.0.2 xyzservices 2022.9.0 y-py 0.5.9 yapf 0.31.0 yarl 1.8.1 ypy-websocket 0.8.2 zict 2.2.0 zipp 3.11.0 zope.interface 5.4.0 ```

Steps I took to install successfully:

sudo apt install wget

wget https://repo.anaconda.com/archive/Anaconda3-2023.07-1-Linux-x86_64.sh

sudo sh Anaconda3-2023.07-1-Linux-x86_64.sh (if you get stuck in the user agreement text, use 'q')

sudo apt install git

sudo apt install build-essential

pip install ninja

sudo apt install libxml2

conda config --add channels conda-forge

conda install -c conda-forge clang (I was getting g++ errors while compiling; I'm not 100% sure this is what fixed them.)

conda install -c conda-forge clang-tools

conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

sudo ln -s /usr/lib/wsl/lib/libcuda.so.1 /usr/local/cuda-11.8/lib64/libcuda.so

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Compile and install xformers (~30 min), or install a prebuilt wheel instead: https://pypi.org/project/xformers/0.0.21.dev577/#files

git clone https://github.com/facebookresearch/xformers

cd xformers

git submodule update --init --recursive

MAX_JOBS=4 python setup.py build (MAX_JOBS sets how many compilation jobs run in parallel; more jobs require more memory. 122GB for 16 cores.)

python setup.py install

Compile and install flash-attention (~15 min)

git clone https://github.com/Dao-AILab/flash-attention

cd flash-attention/

git submodule update --init --recursive

MAX_JOBS=4 python setup.py build (MAX_JOBS sets how many compilation jobs run in parallel; more jobs require more memory. 96GB for 16 cores.)

python setup.py install 

And Voilà!
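The memory figures quoted in the steps (96GB for 16 cores for flash-attention, 122GB for 16 cores for xformers) work out to roughly 6GB and 7.6GB per parallel job. As a back-of-the-envelope sketch, a MAX_JOBS value for a given machine could be estimated like this. The per-job figure is taken from those numbers as an assumption, it will vary with compiler and version, and `suggest_max_jobs` is a hypothetical helper, not part of either project.

```python
import math


def suggest_max_jobs(available_ram_gb, cpu_cores, gb_per_job=6.0):
    """Pick a MAX_JOBS value that should fit in RAM; never returns less than 1.

    gb_per_job defaults to 96 / 16 = 6.0, the flash-attention figure from the
    steps above (an assumption, not a measured constant).
    """
    by_memory = math.floor(available_ram_gb / gb_per_job)
    return max(1, min(cpu_cores, by_memory))


# Example: 32GB of RAM and 16 cores, building flash-attention.
print(suggest_max_jobs(32, 16))                       # limited by memory
# Same machine, using the xformers figure (122GB for 16 cores).
print(suggest_max_jobs(32, 16, gb_per_job=122 / 16))
```

The result can then be passed on the command line in the same way as in the steps, for example `MAX_JOBS=5 python setup.py build`.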

Benchmark works:

```
flash-attention/benchmarks$ python benchmark_flash_attention.py
- Forward pass fn_amp(*inputs, **kwinputs) 1.44 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 6.91 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 61.04 TFLOPs/s, bwd: 45.51 TFLOPs/s, fwd + bwd: 49.08 TFLOPs/s
Pytorch fwd: 15.81 TFLOPs/s, bwd: 18.58 TFLOPs/s, fwd + bwd: 17.70 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 47.66 TFLOPs/s, bwd: 24.86 TFLOPs/s, fwd + bwd: 28.80 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 3.70 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 12.06 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 50.46 TFLOPs/s, bwd: 50.68 TFLOPs/s, fwd + bwd: 50.62 TFLOPs/s
Pytorch fwd: 18.48 TFLOPs/s, bwd: 20.86 TFLOPs/s, fwd + bwd: 20.12 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 37.12 TFLOPs/s, bwd: 28.48 TFLOPs/s, fwd + bwd: 30.51 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 6.79 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 23.08 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 54.76 TFLOPs/s, bwd: 54.58 TFLOPs/s, fwd + bwd: 54.63 TFLOPs/s
Pytorch fwd: 16.22 TFLOPs/s, bwd: 20.44 TFLOPs/s, fwd + bwd: 19.02 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 40.48 TFLOPs/s, bwd: 29.77 TFLOPs/s, fwd + bwd: 32.20 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 17.51 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 51.09 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 57.67 TFLOPs/s, bwd: 55.90 TFLOPs/s, fwd + bwd: 56.40 TFLOPs/s
Pytorch fwd: 7.02 TFLOPs/s, bwd: 9.79 TFLOPs/s, fwd + bwd: 8.79 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 31.39 TFLOPs/s, bwd: 26.90 TFLOPs/s, fwd + bwd: 28.05 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 36.85 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 105.14 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 52.16 TFLOPs/s, bwd: 56.50 TFLOPs/s, fwd + bwd: 55.19 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 29.83 TFLOPs/s, bwd: 26.15 TFLOPs/s, fwd + bwd: 27.10 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 63.98 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 201.09 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 12.44 TFLOPs/s, bwd: 58.29 TFLOPs/s, fwd + bwd: 28.40 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 34.37 TFLOPs/s, bwd: 27.34 TFLOPs/s, fwd + bwd: 29.04 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 2.54 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 15.47 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 39.10 TFLOPs/s, bwd: 33.53 TFLOPs/s, fwd + bwd: 34.95 TFLOPs/s
Pytorch fwd: 11.05 TFLOPs/s, bwd: 16.20 TFLOPs/s, fwd + bwd: 14.29 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 27.00 TFLOPs/s, bwd: 11.10 TFLOPs/s, fwd + bwd: 13.35 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 4.41 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 27.73 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 50.02 TFLOPs/s, bwd: 43.26 TFLOPs/s, fwd + bwd: 44.99 TFLOPs/s
Pytorch fwd: 17.46 TFLOPs/s, bwd: 22.96 TFLOPs/s, fwd + bwd: 21.06 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 31.14 TFLOPs/s, bwd: 12.39 TFLOPs/s, fwd + bwd: 14.96 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 7.38 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 55.73 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 56.53 TFLOPs/s, bwd: 53.46 TFLOPs/s, fwd + bwd: 54.30 TFLOPs/s
Pytorch fwd: 24.90 TFLOPs/s, bwd: 32.85 TFLOPs/s, fwd + bwd: 30.11 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 37.25 TFLOPs/s, bwd: 12.33 TFLOPs/s, fwd + bwd: 15.24 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 14.11 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 126.77 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 56.96 TFLOPs/s, bwd: 55.50 TFLOPs/s, fwd + bwd: 55.91 TFLOPs/s
Pytorch fwd: 28.56 TFLOPs/s, bwd: 34.02 TFLOPs/s, fwd + bwd: 32.26 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 38.97 TFLOPs/s, bwd: 10.84 TFLOPs/s, fwd + bwd: 13.66 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 27.05 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 250.84 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 58.89 TFLOPs/s, bwd: 56.20 TFLOPs/s, fwd + bwd: 56.94 TFLOPs/s
Pytorch fwd: 29.31 TFLOPs/s, bwd: 34.35 TFLOPs/s, fwd + bwd: 32.74 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 40.64 TFLOPs/s, bwd: 10.96 TFLOPs/s, fwd + bwd: 13.85 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 58.42 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 526.64 ms 1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 59.04 TFLOPs/s, bwd: 56.82 TFLOPs/s, fwd + bwd: 57.44 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 37.64 TFLOPs/s, bwd: 10.44 TFLOPs/s, fwd + bwd: 13.16 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 1.07 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 3.94 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 42.95 TFLOPs/s, bwd: 30.47 TFLOPs/s, fwd + bwd: 33.23 TFLOPs/s
Pytorch fwd: 5.68 TFLOPs/s, bwd: 9.05 TFLOPs/s, fwd + bwd: 7.74 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 32.26 TFLOPs/s, bwd: 21.79 TFLOPs/s, fwd + bwd: 24.02 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 1.57 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 6.96 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 46.05 TFLOPs/s, bwd: 40.01 TFLOPs/s, fwd + bwd: 41.57 TFLOPs/s
Pytorch fwd: 5.89 TFLOPs/s, bwd: 10.38 TFLOPs/s, fwd + bwd: 8.52 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 43.75 TFLOPs/s, bwd: 24.70 TFLOPs/s, fwd + bwd: 28.21 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 3.07 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 12.83 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 50.89 TFLOPs/s, bwd: 47.35 TFLOPs/s, fwd + bwd: 48.31 TFLOPs/s
Pytorch fwd: 5.84 TFLOPs/s, bwd: 10.29 TFLOPs/s, fwd + bwd: 8.45 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 44.71 TFLOPs/s, bwd: 26.78 TFLOPs/s, fwd + bwd: 30.24 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 6.33 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 25.08 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 53.13 TFLOPs/s, bwd: 51.93 TFLOPs/s, fwd + bwd: 52.27 TFLOPs/s
Pytorch fwd: 6.14 TFLOPs/s, bwd: 10.69 TFLOPs/s, fwd + bwd: 8.82 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 43.41 TFLOPs/s, bwd: 27.40 TFLOPs/s, fwd + bwd: 30.63 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 16.82 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 56.23 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 54.96 TFLOPs/s, bwd: 54.43 TFLOPs/s, fwd + bwd: 54.58 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 32.68 TFLOPs/s, bwd: 24.44 TFLOPs/s, fwd + bwd: 26.34 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 33.91 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 113.57 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 13.41 TFLOPs/s, bwd: 49.96 TFLOPs/s, fwd + bwd: 28.08 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 32.43 TFLOPs/s, bwd: 24.20 TFLOPs/s, fwd + bwd: 26.09 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 1.30 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 8.24 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 39.87 TFLOPs/s, bwd: 27.40 TFLOPs/s, fwd + bwd: 30.09 TFLOPs/s
Pytorch fwd: 6.06 TFLOPs/s, bwd: 8.46 TFLOPs/s, fwd + bwd: 7.60 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 26.46 TFLOPs/s, bwd: 10.43 TFLOPs/s, fwd + bwd: 12.61 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 2.06 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 14.41 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 48.39 TFLOPs/s, bwd: 37.02 TFLOPs/s, fwd + bwd: 39.68 TFLOPs/s
Pytorch fwd: 5.93 TFLOPs/s, bwd: 15.80 TFLOPs/s, fwd + bwd: 10.71 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 33.39 TFLOPs/s, bwd: 11.93 TFLOPs/s, fwd + bwd: 14.61 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 3.73 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 27.85 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 51.95 TFLOPs/s, bwd: 48.02 TFLOPs/s, fwd + bwd: 49.08 TFLOPs/s
Pytorch fwd: 9.68 TFLOPs/s, bwd: 16.40 TFLOPs/s, fwd + bwd: 13.69 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 36.82 TFLOPs/s, bwd: 12.34 TFLOPs/s, fwd + bwd: 15.23 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 7.20 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 66.15 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 52.58 TFLOPs/s, bwd: 51.80 TFLOPs/s, fwd + bwd: 52.02 TFLOPs/s
Pytorch fwd: 10.16 TFLOPs/s, bwd: 17.19 TFLOPs/s, fwd + bwd: 14.35 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 38.16 TFLOPs/s, bwd: 10.39 TFLOPs/s, fwd + bwd: 13.12 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 13.94 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 127.61 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 51.47 TFLOPs/s, bwd: 52.68 TFLOPs/s, fwd + bwd: 52.33 TFLOPs/s
Pytorch fwd: 9.87 TFLOPs/s, bwd: 17.38 TFLOPs/s, fwd + bwd: 14.28 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 39.44 TFLOPs/s, bwd: 10.77 TFLOPs/s, fwd + bwd: 13.59 TFLOPs/s
- Forward pass fn_amp(*inputs, **kwinputs) 27.67 ms 1 measurement, 10 runs , 8 threads
- Backward pass f(*inputs, y=y, grad=grad) 268.94 ms 1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 52.00 TFLOPs/s, bwd: 55.48 TFLOPs/s, fwd + bwd: 54.44 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers fwd: 39.73 TFLOPs/s, bwd: 10.22 TFLOPs/s, fwd + bwd: 12.97 TFLOPs/s
```
Short PC Specs

- CPU: AMD Ryzen 7 5800x
- MB: Gigabyte x570 Master
- RAM: 128GB
- GPU: x2 Gigabyte RTX 3090 +NVLINK (using PCIe extensions cable for the 4 slots spacing required)
- Storage: NVMe SSD
- Windows 10
- WSL2 Debian VERSION="12 (bookworm)"
- Windows Nvidia Drivers: 536.23
- Windows Nvidia Cuda: 11.7, 11.8, 12.1

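
For readers who want to sanity-check the TFLOPs/s figures in the benchmark output above, here is a minimal sketch of how a timing can be converted into throughput. It is not taken from this thread: the FLOP count (two matmuls per head, halved for causal masking, backward counted as 2.5x forward) and the total hidden dimension of 2048 (so `nheads = 2048 // headdim`) are assumptions, and the function names are hypothetical. Under those assumptions, the 1.44 ms / 6.91 ms timings printed before the first config appear to line up with its xformers row.

```python
def attention_flops(batch, seqlen, nheads, headdim, causal, mode="fwd"):
    """Approximate FLOPs for one attention pass (assumed accounting, see lead-in)."""
    if mode not in ("fwd", "bwd", "fwd_bwd"):
        raise ValueError(f"unknown mode: {mode}")
    # QK^T and PV are each 2 * seqlen^2 * headdim FLOPs per head.
    fwd = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        fwd //= 2  # roughly half of the score matrix is masked out
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor


def tflops_per_second(flops, seconds):
    """Convert a FLOP count and a wall-clock time into TFLOPs/s."""
    return flops / seconds / 1e12


# First config in the log: causal=False, headdim=64, batch_size=32, seqlen=512.
# ASSUMPTION: hidden dim 2048, hence 32 heads.
nheads = 2048 // 64
fwd = tflops_per_second(attention_flops(32, 512, nheads, 64, False, "fwd"), 1.44e-3)
bwd = tflops_per_second(attention_flops(32, 512, nheads, 64, False, "bwd"), 6.91e-3)
print(f"fwd: {fwd:.2f} TFLOPs/s, bwd: {bwd:.2f} TFLOPs/s")
```

Small differences from the logged values are expected because the timings in the log are rounded to two decimals.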
YuanWind commented 1 year ago

If the build fails because the cutlass.h header is missing, you can copy the cutlass and cute folders from https://github.com/NVIDIA/cutlass/tree/main/include into csrc/flash_attn.
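
The copy step described above could be scripted roughly as follows. This is only a sketch: the helper name `copy_cutlass_headers` is hypothetical, and it assumes you already have a local checkout of the CUTLASS repository and of flash-attention, whose paths you pass in yourself.

```python
import shutil
from pathlib import Path


def copy_cutlass_headers(cutlass_include: Path, flash_attn_csrc: Path) -> list:
    """Copy the `cutlass` and `cute` header folders into the flash_attn sources.

    cutlass_include: path to the `include` directory of a CUTLASS checkout.
    flash_attn_csrc: path to `csrc/flash_attn` in a flash-attention checkout.
    Both paths are supplied by the caller; nothing is downloaded here.
    """
    copied = []
    for name in ("cutlass", "cute"):
        src = cutlass_include / name
        if not src.is_dir():
            raise FileNotFoundError(f"expected header folder not found: {src}")
        dst = flash_attn_csrc / name
        # dirs_exist_ok (Python 3.8+) lets the copy be re-run over an earlier attempt.
        shutil.copytree(src, dst, dirs_exist_ok=True)
        copied.append(dst)
    return copied
```

For example, `copy_cutlass_headers(Path("cutlass/include"), Path("flash-attention/csrc/flash_attn"))` would be the call if both repositories were cloned side by side; adjust the paths to your own layout before rebuilding.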

JoshuaPinaca commented 2 months ago

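    import os
    from transformers.dynamic_module_utils import get_imports

    def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
        """Workaround for FlashAttention"""
        if os.path.basename(filename) != "modeling_florence2.py":
            return get_imports(filename)
        imports = get_imports(filename)
        # Guard the removal so a missing "flash_attn" entry does not raise ValueError.
        if "flash_attn" in imports:
            imports.remove("flash_attn")
        return imports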

This worked for me when loading the Florence-2 model.
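
The comment above does not show how `fixed_get_imports` gets applied. One common way, sketched here under assumptions, is to temporarily patch `transformers.dynamic_module_utils.get_imports` with `unittest.mock.patch` while the model loads. The helper names `drop_flash_attn` and `load_florence2` and the example model id are illustrative, not from the thread.

```python
from __future__ import annotations

import os
from unittest.mock import patch


def drop_flash_attn(imports: list[str]) -> list[str]:
    """Return the import list without the optional flash_attn dependency."""
    return [name for name in imports if name != "flash_attn"]


def load_florence2(model_id: str = "microsoft/Florence-2-base"):
    """Load a Florence-2 checkpoint without requiring flash_attn.

    ASSUMPTION: model_id is only an example; pass the checkpoint you actually use.
    """
    # Imported lazily so drop_flash_attn stays usable without transformers installed.
    from transformers import AutoModelForCausalLM
    from transformers.dynamic_module_utils import get_imports

    def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
        imports = get_imports(filename)  # original function, captured before patching
        if os.path.basename(filename) == "modeling_florence2.py":
            imports = drop_flash_attn(imports)
        return imports

    # The patch is active only inside this block, so other models are unaffected.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        return AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
```

Note that this only bypasses the import check when the model code is loaded; it does not install or build flash-attn, so the model runs with whatever attention implementation it falls back to.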
Chord-Chen-30 commented 3 weeks ago

> If the build fails because the cutlass.h header is missing, you can copy the cutlass and cute folders from https://github.com/NVIDIA/cutlass/tree/main/include into csrc/flash_attn.

Solved my problem! Amazing! How do you know that?