## Bug Description

Installing `transformer-engine[pytorch]==1.11` together with `flash-attn>=2.5.7` results in the following error:

```
E TypeError: flash_attn_func() got an unexpected keyword argument 'block_table'
/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformer_engine/pytorch/attention.py:5073: TypeError
```

## Details
The TransformerEngine v1.11 release changed the flash-attn version constraint in the `transformer_engine/pytorch/setup.py` file (https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/setup.py#L59) from what it was in TransformerEngine <= v1.10 (https://github.com/NVIDIA/TransformerEngine/blob/08a85d3b2657f1d4e0b478f6682c17fe6bba8b05/transformer_engine/pytorch/setup.py#L59).
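As a quick way to check whether an installed flash-attn falls in the affected range (`>=2.5.7`, per the constraint change above), a small helper can compare release tuples. `is_affected` is a hypothetical helper name for illustration, not part of either package:

```python
from importlib.metadata import version, PackageNotFoundError


def is_affected(release: str) -> bool:
    """Return True if a flash-attn release string is >= 2.5.7 (the affected range)."""
    return tuple(int(part) for part in release.split(".")[:3]) >= (2, 5, 7)


try:
    installed = version("flash-attn")
    print(installed, "affected" if is_affected(installed) else "not affected")
except PackageNotFoundError:
    print("flash-attn is not installed")

print(is_affected("2.6.3"))  # True: the version pulled in by transformer-engine[pytorch]==1.11
print(is_affected("2.5.6"))  # False: below the workaround bound
```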
When installing `transformer-engine[pytorch]==1.11` with no other constraints, pip also installs `flash-attn==2.6.3`, which activates the code introduced in commit https://github.com/NVIDIA/TransformerEngine/commit/27c6342ea8ad88034bf04b587dd13cb6088d2474, where the `block_table` kwarg is configured: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5014-L5015

In `transformer_engine/pytorch/attention.py`, the forward-pass function is then decided here: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5017-L5024
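A minimal sketch of the failure mode, using stub functions in place of the real flash-attn entry points (the stub signatures are simplified assumptions, not the actual flash-attn API):

```python
# Stub stand-ins for the two flash-attn entry points (simplified signatures).
def flash_attn_func(q, k, v):
    return "fixed-length attention"


def flash_attn_varlen_func(q, k, v, block_table=None):
    return "variable-length attention"


def run_attention(q, k, v, use_varlen):
    fa_optional_forward_kwargs = {}
    # With flash-attn >= 2.5.7, transformer_engine sets block_table unconditionally,
    # regardless of which function is selected below -- that is the bug.
    fa_optional_forward_kwargs["block_table"] = None
    func = flash_attn_varlen_func if use_varlen else flash_attn_func
    return func(q, k, v, **fa_optional_forward_kwargs)


run_attention(0, 0, 0, use_varlen=True)  # works: varlen accepts block_table

try:
    run_attention(0, 0, 0, use_varlen=False)
except TypeError as exc:
    print(exc)  # flash_attn_func() got an unexpected keyword argument 'block_table'
```

The fix would presumably be to add `block_table` to the optional kwargs only when the selected function supports it.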
If `func` is set to `flash_attn_varlen_func`, which correctly consumes `fa_optional_forward_kwargs["block_table"]`, the code (probably) works as intended. However, when `func` is set to `flash_attn_func`, which does not support the `block_table` argument, it results in the following error:

```
E TypeError: flash_attn_func() got an unexpected keyword argument 'block_table'
/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformer_engine/pytorch/attention.py:5073: TypeError
```

## How to reproduce
The bug can be reproduced by running `pytest TransformerEngine/tests/pytorch/test_sanity.py::test_sanity_gpt`: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/tests/pytorch/test_sanity.py#L571

## Workaround
Until the bug is fixed, the workaround seems to be to force the installation of `flash-attn<2.5.7` before installing transformer_engine. This bug has seemingly existed since v1.10, but it only became a point of failure after v1.11 allowed the installation of newer versions of flash-attn.
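Assuming a pip-based environment, the workaround install order can be sketched as (a sketch of the pinning described above, not an officially documented procedure):

```shell
# Pin flash-attn below 2.5.7 first, so the block_table code path never activates.
pip install "flash-attn<2.5.7"

# Then install transformer-engine; per the workaround above, the already-installed
# flash-attn should be left in place rather than upgraded to 2.6.x.
pip install "transformer-engine[pytorch]==1.11"
```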