Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Error with PyTorch containers #813

Open parmesant opened 3 months ago

parmesant commented 3 months ago

GPU: 2x RTX 4090
Memory: 128GB
CPU: 64 cores
CUDA: 12.3.52
NVIDIA Driver: 545.23.08
PyTorch Container: 23.11
nvcc --version (on host machine): Cuda compilation tools, release 12.3, V12.3.107 Build cuda_12.3.r12.3/compiler.33567101_0

I am trying to fine-tune Mistral 7B Instruct v0.2 and am running into these errors:

First, I run into an error caused by protobuf. I have protobuf 4.24.4, and the error message suggests downgrading it to 3.20.x or lower. After downgrading, pip reports "cudf 23.10.0 requires protobuf<5,>=4.21, but you have protobuf 3.20.3 which is incompatible.", but things continue to work. The original error:

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mistral_finetuning_script.py", line 84, in <module>
    from trl import SFTTrainer
  File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
    from .core import set_seed
  File "./python3.10/site-packages/trl/core.py", line 25, in <module>
    from transformers import top_k_top_p_filtering
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
    value = getattr(module, name)
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
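The second workaround in the message can also be applied from inside the script instead of downgrading protobuf, which would avoid the cudf conflict at the cost of slower, pure-Python parsing. A minimal sketch, assuming the variable only takes effect if it is set before anything in the process imports protobuf (here, before `from trl import SFTTrainer`); the helper name `force_pure_python_protobuf` is hypothetical.

```python
import os
import sys


def force_pure_python_protobuf() -> bool:
    """Select protobuf's pure-Python backend via PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION.

    Returns True if the setting can still take effect (protobuf not yet imported
    in this process), False if protobuf was already loaded.
    """
    # protobuf reads this variable when it is first imported, so set it
    # before trl/transformers pull protobuf in.
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
    return "google.protobuf" not in sys.modules


if __name__ == "__main__":
    if not force_pure_python_protobuf():
        print("protobuf already imported; export the variable in the shell instead")
    # Only now import the libraries that depend on protobuf, for example:
    # from trl import SFTTrainer
```

Exporting `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` in the shell before launching the script should have the same effect without touching the code.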

After downgrading to 3.20.3, I run into an issue with flash-attn:

    trainer = SFTTrainer(
  File "./python3.10/site-packages/trl/trainer/sft_trainer.py", line 163, in __init__
    model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
  File "./python3.10/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
  File "./python3.10/site-packages/transformers/modeling_utils.py", line 3588, in from_pretrained
    config = cls._autoset_attn_implementation(
  File "./python3.10/site-packages/transformers/modeling_utils.py", line 1387, in _autoset_attn_implementation
  File "./python3.10/site-packages/transformers/modeling_utils.py", line 1483, in _check_and_enable_flash_attn_2
    raise ImportError(
ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: you need flash_attn package version to be greater or equal than 2.1.0. Detected version 2.0.4. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
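The installed flash-attn version can be checked against the 2.1.0 minimum named in this error before the model is loaded. A minimal sketch using the standard library's `importlib.metadata`; the helper names are hypothetical, and the simple parser assumes release strings such as `2.0.4` or `2.3.6+cu122` (the local `+...` tag and any pre-release suffix are ignored).

```python
from importlib import metadata
from typing import Optional, Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Turn '2.0.4' into (2, 0, 4); drops a '+local' tag and non-numeric suffixes."""
    release = version.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare numeric releases, padding with zeros so '2.1' == '2.1.0'."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(dist: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("flash-attn")
    if found is None:
        print("flash-attn is not installed")
    elif not meets_minimum(found, "2.1.0"):
        print(f"flash-attn {found} is older than the 2.1.0 required by transformers")
    else:
        print(f"flash-attn {found} satisfies >= 2.1.0")
```

A full PEP 440 comparison (for example via the third-party `packaging` library) would be more robust; the hand-rolled parser only keeps the sketch dependency-free.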

After upgrading flash-attn to 2.3.6 (the latest release when the PyTorch container was published), pip prints this warning:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. transformer-engine 1.0.0+66d91d5 requires flash-attn<=2.0.4,>=1.0.6, but you have flash-attn 2.3.6 which is incompatible.

and importing fails with this error:

Traceback (most recent call last):
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1364, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "./python3.10/site-packages/transformers/generation/utils.py", line 93, in <module>
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
  File "./python3.10/site-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "./python3.10/site-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "./python3.10/site-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "./python3.10/site-packages/accelerate/utils/__init__.py", line 153, in <module>
    from .launch import (
  File "./python3.10/site-packages/accelerate/utils/launch.py", line 33, in <module>
    from ..utils.other import is_port_in_use, merge_dicts
  File "./python3.10/site-packages/accelerate/utils/other.py", line 36, in <module>
    from .transformer_engine import convert_model
  File "./python3.10/site-packages/accelerate/utils/transformer_engine.py", line 21, in <module>
    import transformer_engine.pytorch as te
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/__init__.py", line 11, in <module>
    from .attention import DotProductAttention
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 61, in <module>
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mistral_finetuning_script.py", line 84, in <module>
    from trl import SFTTrainer
  File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
    from .core import set_seed
  File "./python3.10/site-packages/trl/core.py", line 25, in <module>
    from transformers import top_k_top_p_filtering
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
    value = getattr(module, name)
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
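The pip warning printed before this traceback shows why upgrading alone cannot satisfy everything in the 23.11 container: transformers asks for flash-attn >= 2.1.0, while the bundled transformer-engine 1.0.0 pins flash-attn >= 1.0.6, <= 2.0.4, and the traceback shows accelerate importing transformer_engine (which in turn imports flash_attn). A small sketch that intersects the two ranges; the `intersect` helper and the tuple representation are hypothetical and only handle inclusive numeric bounds.

```python
from typing import Optional, Tuple

Version = Tuple[int, int, int]
# (inclusive lower bound, inclusive upper bound or None for "no upper bound")
Range = Tuple[Version, Optional[Version]]


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Return the overlap of two inclusive version ranges, or None if disjoint."""
    low = max(a[0], b[0])
    highs = [h for h in (a[1], b[1]) if h is not None]
    high = min(highs) if highs else None
    if high is not None and low > high:
        return None
    return (low, high)


# flash-attn requirements taken from the messages in this issue
transformers_needs: Range = ((2, 1, 0), None)             # ">= 2.1.0"
transformer_engine_needs: Range = ((1, 0, 6), (2, 0, 4))  # ">=1.0.6, <=2.0.4"

if __name__ == "__main__":
    overlap = intersect(transformers_needs, transformer_engine_needs)
    print("compatible flash-attn range:", overlap)
```

With these two constraints the overlap is empty, so in this container one of the two requirements has to be broken; which one is acceptable depends on whether transformer-engine is actually used.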

Even with the latest version, 2.5.2 (at the time of writing), I get a similar error:

Traceback (most recent call last):
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1364, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "./python3.10/site-packages/transformers/generation/utils.py", line 93, in <module>
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
  File "./python3.10/site-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "./python3.10/site-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "./python3.10/site-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "./python3.10/site-packages/accelerate/utils/__init__.py", line 153, in <module>
    from .launch import (
  File "./python3.10/site-packages/accelerate/utils/launch.py", line 33, in <module>
    from ..utils.other import is_port_in_use, merge_dicts
  File "./python3.10/site-packages/accelerate/utils/other.py", line 36, in <module>
    from .transformer_engine import convert_model
  File "./python3.10/site-packages/accelerate/utils/transformer_engine.py", line 21, in <module>
    import transformer_engine.pytorch as te
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/__init__.py", line 11, in <module>
    from .attention import DotProductAttention
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 61, in <module>
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mistral_finetuning_script.py", line 84, in <module>
    from trl import SFTTrainer
  File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
    from .core import set_seed
  File "./python3.10/site-packages/trl/core.py", line 25, in <module>
    from transformers import top_k_top_p_filtering
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
    value = getattr(module, name)
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_
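Both tracebacks end in an `undefined symbol` error while loading the compiled extension `flash_attn_2_cuda`, which usually indicates the extension was compiled against a different PyTorch build than the one installed. A minimal diagnostic sketch that loads the extension directly, bypassing transformers and accelerate, and reports the relevant package versions; the helper name `check_flash_attn_extension` is hypothetical, and the script still runs when torch or flash-attn is missing.

```python
import importlib
from importlib import metadata
from typing import Dict, Optional


def _version(dist: str) -> Optional[str]:
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_flash_attn_extension() -> Dict[str, Optional[str]]:
    """Report package versions and whether the flash-attn CUDA extension loads."""
    report: Dict[str, Optional[str]] = {
        "torch": _version("torch"),
        "flash-attn": _version("flash-attn"),
        "transformer-engine": _version("transformer-engine"),
        "extension_error": None,
    }
    try:
        # Import torch first so the shared libraries the extension links against are loaded.
        importlib.import_module("torch")
        importlib.import_module("flash_attn_2_cuda")
    except (ImportError, OSError) as exc:  # covers 'undefined symbol' and missing modules
        report["extension_error"] = str(exc)
    return report


if __name__ == "__main__":
    result = check_flash_attn_extension()
    for key, value in result.items():
        print(f"{key}: {value}")
    error = result["extension_error"]
    if error and "undefined symbol" in error:
        print("flash_attn_2_cuda was likely built against a different torch build")
```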

Some other things I've tried:

tridao commented 3 months ago

Try flash-attn 2.5.1 on nvcr 23.12 or 24.01.

tsvisab commented 1 month ago

The symbol "_ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_" is a mangled C++ function name. Demangling it with a demangler tool gives at::_ops::sum_IntList_out::call(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>, at::Tensor&). So what happens is that flash-attn was built against a PyTorch version that does not match the installed one. I don't know which version it should be built against.
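The demangling step can be reproduced locally without an online tool: GNU binutils ships `c++filt`, which reads mangled names on standard input. A minimal sketch that shells out to it, assuming `c++filt` is on `PATH`; the wrapper name `demangle` is hypothetical and returns None when the tool is unavailable.

```python
import shutil
import subprocess
from typing import Optional

# The symbol from the flash-attn 2.5.2 traceback in this issue
SYMBOL = (
    "_ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016"
    "OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_"
)


def demangle(symbol: str) -> Optional[str]:
    """Demangle an Itanium C++ ABI symbol with c++filt; None if c++filt is missing."""
    tool = shutil.which("c++filt")
    if tool is None:
        return None
    result = subprocess.run(
        [tool], input=symbol + "\n", capture_output=True, text=True, check=True
    )
    return result.stdout.strip()


if __name__ == "__main__":
    print(demangle(SYMBOL) or "c++filt not found")
```

Applied to the symbol from the 2.3.6 attempt, `_ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv`, the same tool yields `c10::SymbolicShapeMeta::init_is_contiguous() const`, which is likewise a PyTorch (c10) internal, consistent with a torch/extension mismatch.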

lucapericlp commented 1 month ago

@tsvisab, not sure if you've resolved this, but for anyone else who comes across it: installing from source solved it for me.
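The comment does not say exactly how the source build was done. One way to force a local build against the container's own PyTorch is sketched below under assumptions to verify against the flash-attn README for your version: `--no-build-isolation` makes pip compile against the torch already installed, `--no-cache-dir` stops pip from reusing a wheel built earlier, `MAX_JOBS` limits parallel compile jobs to bound RAM use, and, as far as I understand its `setup.py`, `FLASH_ATTENTION_FORCE_BUILD=TRUE` skips downloading a prebuilt wheel. The helpers `source_build_command` and `build` are hypothetical; 2.5.1 is the version suggested earlier in the thread.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple


def source_build_command(version: str, max_jobs: int = 4) -> Tuple[List[str], Dict[str, str]]:
    """Return the pip command and environment for compiling flash-attn locally."""
    cmd = [
        sys.executable, "-m", "pip", "install",
        f"flash-attn=={version}",
        "--no-build-isolation",  # build against the torch already in the container
        "--no-cache-dir",        # do not reuse a previously built wheel
    ]
    env = dict(os.environ)
    env["FLASH_ATTENTION_FORCE_BUILD"] = "TRUE"  # assumption: forces a build instead of a prebuilt wheel
    env["MAX_JOBS"] = str(max_jobs)              # limits parallel compile jobs
    return cmd, env


def build(version: str) -> None:
    """Run the source build; compiling the CUDA kernels can take a long time."""
    cmd, env = source_build_command(version)
    subprocess.run(cmd, env=env, check=True)


if __name__ == "__main__":
    command, _ = source_build_command("2.5.1")
    print(" ".join(command))
    # build("2.5.1")  # uncomment to actually compile and install
```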