Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Failed to build dropout-layer-norm #587

Open vgoklani opened 11 months ago

vgoklani commented 11 months ago

Hey there,

I'm not able to build the dropout-layer-norm.

I used this Docker image: nvcr.io/nvidia/pytorch:23.09-py3 and then installed the flash-attention components via:

flash_attn_version=2.3.0

pip install flash-attn==${flash_attn_version}

cd

git clone https://github.com/HazyResearch/flash-attention \
    && cd flash-attention && git checkout v${flash_attn_version} \
    && cd csrc/fused_softmax && pip install . && cd ../../ \
    && cd csrc/rotary && pip install . && cd ../../ \
    && cd csrc/xentropy && pip install . && cd ../../ \
    && cd csrc/layer_norm && pip install . && cd ../../ \
    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
    && cd csrc/ft_attention && pip install . && cd ../../ \
    && cd .. && rm -rf flash-attention

This is a subset of the traceback:

  The above exception was the direct cause of the following exception:

  Traceback (most recent call last):
    File "<string>", line 2, in <module>
    File "<pip-setuptools-caller>", line 34, in <module>
    File "/root/temp/flash-attention/csrc/layer_norm/setup.py", line 198, in <module>
      setup(
    File "/usr/local/lib/python3.10/dist-packages/setuptools/__init__.py", line 103, in setup
      return distutils.core.setup(**attrs)
    File "/usr/lib/python3.10/distutils/core.py", line 148, in setup
      dist.run_commands()
    File "/usr/lib/python3.10/distutils/dist.py", line 966, in run_commands
      self.run_command(cmd)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/local/lib/python3.10/dist-packages/wheel/bdist_wheel.py", line 364, in run
      self.run_command("build")
    File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/lib/python3.10/distutils/command/build.py", line 135, in run
      self.run_command(cmd_name)
    File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/dist.py", line 989, in run_command
      super().run_command(command)
    File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 88, in run
      _build_ext.run(self)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 340, in run
      self.build_extensions()
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 865, in build_extensions
      build_ext.build_extensions(self)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 449, in build_extensions
      self._build_extensions_serial()
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 474, in _build_extensions_serial
      self.build_extension(ext)
    File "/usr/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 249, in build_extension
      _build_ext.build_extension(self, ext)
    File "/usr/local/lib/python3.10/dist-packages/Cython/Distutils/build_ext.py", line 127, in build_extension
      super(build_ext, self).build_extension(ext)
    File "/usr/lib/python3.10/distutils/command/build_ext.py", line 529, in build_extension
      objects = self.compiler.compile(sources,
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 678, in unix_wrap_ninja_compile
      _write_ninja_file_and_compile_objects(
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1590, in _write_ninja_file_and_compile_objects
      _run_ninja_build(
    File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1933, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension
  [end of output]

The other modules all built successfully.

These are my device specs:

gpu:0 - NVIDIA RTX 6000 Ada Generation
device_properties.name NVIDIA RTX 6000 Ada Generation
multi-processor-count: 142
gpu:1 - NVIDIA RTX 6000 Ada Generation
device_properties.name NVIDIA RTX 6000 Ada Generation
multi-processor-count: 142
PyTorch version: 2.1.0a0+32f93b1
CUDA Version: 12.2
cuDNN version is: 8905
Arch version is: sm_52 sm_60 sm_61 sm_70 sm_72 sm_75 sm_80 sm_86 sm_87 sm_90 compute_90
ARCH LIST: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_72', 'sm_75', 'sm_80', 'sm_86', 'sm_87', 'sm_90', 'compute_90']
Device Capability: (8, 9)
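
For reference, a minimal sketch of how specs like the above can be gathered with PyTorch (not necessarily the exact script used for this report):

import torch

print("PyTorch version:", torch.__version__)
print("CUDA Version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
print("ARCH LIST:", torch.cuda.get_arch_list())
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"gpu:{i} - {props.name}")
    print("multi-processor-count:", props.multi_processor_count)
print("Device Capability:", torch.cuda.get_device_capability(0))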

I was able to build everything with flash_attn_version=2.2.1 without any issues.

thanks!

===========

One quick update: I checked this whl:

and it looks like it didn't build correctly there either:

ModuleNotFoundError: No module named 'dropout_layer_norm'

The other modules all imported correctly:

import flash_attn
from flash_attn import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb_func
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.ops.activations import swiglu as swiglu_gated
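
For anyone hitting the same thing, a quick way to check whether the extension built is to try importing it directly; dropout_layer_norm is the module name from the error above (a minimal sketch, not part of the original report):

import importlib

# flash_attn imports fine in this environment; dropout_layer_norm is the one that fails.
for mod in ("flash_attn", "dropout_layer_norm"):
    try:
        importlib.import_module(mod)
        print(f"{mod}: OK")
    except ImportError as exc:
        print(f"{mod}: missing ({exc})")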
Lvjinhong commented 11 months ago

I've tried installing flash-attn with pip install flash-attn==2.2.1 and with flash-attn==2.3. The installation itself ultimately succeeds. However, when I attempt distributed training with Megatron-LM, I consistently encounter the following issue:

[screenshot of the error, not reproduced here]

Additionally, when I tried building from the source code, the issue persisted.

tridao commented 10 months ago

dropout_layer_norm is a separate extension. You don't have to use it.
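
For code that only wants the fused kernel when it is available, a guarded import with a fallback is a common pattern. The wrapper path below (flash_attn.ops.layer_norm.dropout_add_layer_norm) is my assumption about where the csrc/layer_norm extension is exposed, so treat this as a sketch rather than the library's guaranteed API:

# Try the fused dropout + add + layer norm wrapper; fall back to plain PyTorch
# if the dropout_layer_norm extension was never built.
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm  # assumed wrapper module
    HAS_FUSED_LN = True
except ImportError:
    dropout_add_layer_norm = None  # use e.g. torch.nn.LayerNorm + torch.nn.Dropout instead
    HAS_FUSED_LN = False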

vgoklani commented 10 months ago

@tridao to be clear, we want to use it :) but it's not building correctly.

From above:

&& cd csrc/layer_norm && pip install . && cd ../../ \
sipie800 commented 5 months ago

Same here. The compile fails right after the object files are generated, or maybe not all of them are generated, I don't know.

tridao commented 5 months ago

As mentioned in https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm: As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based implementation.

sipie800 commented 5 months ago

As mentioned in https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm: As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based implementation.

Thanks for replying. Does that mean a model developer has to change how their model uses flash_attn in order to use the Triton one, or will flash_attn switch to it internally by itself?

tridao commented 5 months ago

Internally we already use the Triton implementation for layernorm.
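
For anyone who wants to call the Triton kernels directly rather than through the models in this repo, recent releases ship them under flash_attn.ops.triton.layer_norm; the function name and signature below are my assumption from that module, so double-check against your installed version:

import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn  # assumed path (flash-attn >= 2.4)

hidden = 1024
x = torch.randn(2, 16, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)
# Triton-based layer norm; no dropout_layer_norm CUDA extension required.
out = layer_norm_fn(x, weight, bias, eps=1e-5)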

sipie800 commented 5 months ago

I'm using the Qwen LLM in the ModelScope framework, with flash_attn 2.5.6 installed.


Try importing flash-attention for faster inference... Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm

So can I just ignore the warning above when ModelScope is loading the model?

tridao commented 5 months ago

Sorry, I can't control what the Qwen implementation uses.

sipie800 commented 5 months ago

Sorry, I can't control what the Qwen implementation uses.

That's true. Yet if flash_attn uses the Triton layernorm internally, shouldn't there be no such warning? They are just calling layernorm, whether it's the Triton one or the old one? Or is the Triton one not actually a simple drop-in replacement?

tridao commented 5 months ago

The warning is printed from Qwen's code. I can't control that.

qyr0403 commented 2 months ago

Have you solved this problem? I'm also running into it.

htchentyut commented 1 month ago

https://github.com/OpenGVLab/InternVideo/blob/main/InternVideo2/multi_modality/INSTALL.md