saimidu opened this issue 1 month ago
@ksivaman could we cherry-pick this PR into the 1.11 release, please? Thanks! https://github.com/NVIDIA/TransformerEngine/pull/1222
Just updated release_v1.11.
Thank you! Will you also be publishing a new patch version for v1.11 on PyPI (https://pypi.org/project/transformer-engine/#history)?
Bug Description
Using `transformer-engine[pytorch]==1.11` with `flash-attn>=2.5.7` results in the following error:

Details
The TransformerEngine v1.11 release introduced a change to the flash-attn version constraints in the transformer_engine/pytorch/setup.py file: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/setup.py#L59 compared to what they were in TransformerEngine <= v1.10: https://github.com/NVIDIA/TransformerEngine/blob/08a85d3b2657f1d4e0b478f6682c17fe6bba8b05/transformer_engine/pytorch/setup.py#L59
When installing `transformer-engine[pytorch]==1.11` with no other constraints, it also installs `flash-attn==2.6.3`, which activates the code introduced in commit https://github.com/NVIDIA/TransformerEngine/commit/27c6342ea8ad88034bf04b587dd13cb6088d2474, where the `block_table` kwarg is configured: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5014-L5015

When using transformer_engine/pytorch/attention.py, the forward pass function used is decided here: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5017-L5024
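To make the failure mode concrete, here is a self-contained toy sketch of the dispatch pattern described above. The two stub functions only mimic the relevant signatures (they are not the real flash-attn API), and the boolean flag stands in for TransformerEngine's internal flash-attn version check:

```python
def flash_attn_varlen_func_stub(q, k, v, block_table=None):
    """Stand-in that, like flash_attn_varlen_func, accepts block_table."""
    return "varlen path ok"


def flash_attn_func_stub(q, k, v):
    """Stand-in that, like flash_attn_func, does NOT accept block_table."""
    return "non-varlen path ok"


flash_attn_is_2_5_7_plus = True  # what you get when pip resolves flash-attn==2.6.3

fa_optional_forward_kwargs = {}
if flash_attn_is_2_5_7_plus:
    # The optional kwarg is populated based only on the installed flash-attn version...
    fa_optional_forward_kwargs["block_table"] = None

for func in (flash_attn_varlen_func_stub, flash_attn_func_stub):
    try:
        # ...but it is forwarded to whichever function the dispatch selected.
        func("q", "k", "v", **fa_optional_forward_kwargs)
        print(f"{func.__name__}: accepted block_table")
    except TypeError as exc:
        # The non-varlen stand-in lands here, mirroring the reported failure:
        # "... got an unexpected keyword argument 'block_table'"
        print(f"{func.__name__}: {exc}")
```

Running the sketch shows the varlen stand-in accepting the kwarg while the non-varlen stand-in raises a `TypeError`, which is the shape of the failure described below.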
If `func` is set to `flash_attn_varlen_func`, which correctly consumes `fa_optional_forward_kwargs["block_table"]`, the code (probably) works as intended. However, when `func` is set to `flash_attn_func`, which does not support the `block_table` arg, it results in the following error:

How to reproduce
The bug can be reproduced by running `pytest TransformerEngine/tests/pytorch/test_sanity.py::test_sanity_gpt`: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/tests/pytorch/test_sanity.py#L571

Workaround
Until the bug is fixed, the workaround seems to be to force the installation of `"flash-attn<2.5.7"` before installing transformer_engine. This bug has seemingly existed since v1.10, but it only became a point of failure after v1.11 allowed the installation of newer versions of flash-attn.
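As an additional guard, one option is to fail fast at import time if an incompatible flash-attn version is present. This is only a defensive sketch, not part of the report above; it assumes the installed distribution is discoverable under the name `flash-attn` and that the `packaging` library is available:

```python
# Defensive sketch: refuse to proceed if the installed flash-attn is new enough
# to trigger the block_table code path described above.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    fa_version = Version(version("flash-attn"))  # assumes this distribution name
except PackageNotFoundError:
    fa_version = None  # flash-attn not installed; nothing to check

if fa_version is not None and fa_version >= Version("2.5.7"):
    raise RuntimeError(
        f"flash-attn {fa_version} detected; pin 'flash-attn<2.5.7' until the "
        "transformer-engine 1.11 block_table issue is resolved."
    )

import transformer_engine.pytorch as te  # noqa: E402  -- import only after the check
```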