Dao-AILab / flash-attention

Fast and memory-efficient exact attention

How to install the latest flash-attn on NGC containers #734

Open · tcapelle opened this issue 10 months ago

tcapelle commented 10 months ago

What is the best way to do this?

Currently I am unable to install it on the 23.12 or 23.11 images.

The image comes with flash-attn 2.0.4 by default.

# My Dockerfile...
FROM nvcr.io/nvidia/pytorch:23.11-py3

## Flash Attention 2
RUN FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -U flash-attn

and the error:

--expt-extended-lambda --use_fast_math -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=1
      ninja: build stopped: subcommand failed.
      Traceback (most recent call last):
        File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2100, in _run_ninja_build
          subprocess.run(
        File "/usr/lib/python3.10/subprocess.py", line 526, in run
          raise CalledProcessError(retcode, process.args,
      subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

      The above exception was the direct cause of the following exception:

      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "/tmp/pip-install-bza7pl88/flash-attn_acb2fcb058194bbf8ae0b2ce8aaa8f52/setup.py", line 285, in <module>
          setup(
        File "/usr/local/lib/python3.10/dist-packages/setuptools/__init__.py", line 103, in setup
          return distutils.core.setup(**attrs)
        File "/usr/lib/python3.10/distutils/core.py", line 148, in setup
          dist.run_commands()
        File "/usr/lib/python3.10/distutils/dist.py", line 966, in run_commands
          self.run_command(cmd)
        File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
          cmd_obj.run()
        File "/tmp/pip-install-bza7pl88/flash-attn_acb2fcb058194bbf8ae0b2ce8aaa8f52/setup.py", line 260, in run
          return super().run()
        File "/usr/local/lib/python3.10/dist-packages/wheel/bdist_wheel.py", line 369, in run
          self.run_command("build")
        File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
          self.distribution.run_command(command)
        File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
          cmd_obj.run()
        File "/usr/lib/python3.10/distutils/command/build.py", line 135, in run
          self.run_command(cmd_name)
        File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
          self.distribution.run_command(command)
        File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
          cmd_obj.run()
        File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 88, in run
          _build_ext.run(self)
        File "/usr/lib/python3.10/distutils/command/build_ext.py", line 340, in run
          self.build_extensions()
        File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 873, in build_extensions
          build_ext.build_extensions(self)
        File "/usr/lib/python3.10/distutils/command/build_ext.py", line 449, in build_extensions
          self._build_extensions_serial()
        File "/usr/lib/python3.10/distutils/command/build_ext.py", line 474, in _build_extensions_serial
          self.build_extension(ext)
        File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 249, in build_extension
          _build_ext.build_extension(self, ext)
        File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/build_ext.py", line 135, in build_extension
          super(build_ext, self).build_extension(ext)
        File "/usr/lib/python3.10/distutils/command/build_ext.py", line 529, in build_extension
          objects = self.compiler.compile(sources,
        File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 686, in unix_wrap_ninja_compile
          _write_ninja_file_and_compile_objects(
        File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1774, in _write_ninja_file_and_compile_objects
          _run_ninja_build(
        File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2116, in _run_ninja_build
          raise RuntimeError(message) from e
      RuntimeError: Error compiling objects for extension
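
The actual nvcc error is truncated above, but two common causes of this kind of ninja failure are pip's isolated build environment (the extension gets compiled against a torch other than the one in the container) and ninja launching too many parallel nvcc jobs and running out of memory. A possible variant of the Dockerfile that guards against both, following the MAX_JOBS hint in the flash-attn README (the value 4 is a guess and may need tuning for the build machine):

FROM nvcr.io/nvidia/pytorch:23.11-py3

## Flash Attention 2
# --no-build-isolation compiles against the container's own torch;
# MAX_JOBS caps parallel nvcc jobs so ninja doesn't exhaust RAM.
RUN MAX_JOBS=4 FLASH_ATTENTION_FORCE_BUILD=TRUE \
    pip install -U flash-attn --no-build-isolation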
tcapelle commented 10 months ago

I have also tried RUN pip install flash-attn --no-build-isolation, but even though it installs, importing it then raises:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/mixtral/simple_inference.py", line 7, in <module>
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 1372, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 1384, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.training_args because of the following error (look up to see its traceback):
/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
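
An undefined torch symbol like this typically means the flash_attn_2_cuda extension was built against a different torch release than the one shipped in the NGC image, so the shared object cannot resolve the symbol at import time (without FLASH_ATTENTION_FORCE_BUILD, the install may fetch a prebuilt binary rather than compiling locally). A quick check inside the container, just to compare the two versions:

python -c "import torch; print(torch.__version__)"
pip show flash-attn | grep -i '^version'

If these disagree with what the binary was built for, forcing a source build with --no-build-isolation (as in the sketch above) links the extension against the container's torch.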
tridao commented 10 months ago

It's because of the torch version change. NGC PyTorch 23.12 should work with flash-attn v2.4.0.post1 now.
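
A minimal sketch of that combination, assuming the v2.4.0.post1 release is binary-compatible with the torch build in the 23.12 image (the version pin below is taken from the comment above, not verified here):

FROM nvcr.io/nvidia/pytorch:23.12-py3

# Pin the release suggested above; --no-build-isolation keeps the
# container's torch and CUDA toolchain visible during the install.
RUN pip install flash-attn==2.4.0.post1 --no-build-isolation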

tcapelle commented 9 months ago

Nope, not working yet...