NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Bug in TransformerEngine v1.11 for PyTorch when using flash-attn>=2.5.7 #1236

Open saimidu opened 1 day ago

saimidu commented 1 day ago

Bug Description

Using transformer-engine[pytorch]==1.11 with flash-attn>=2.5.7 results in the following error:

>                   output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
                    )
E                   TypeError: flash_attn_func() got an unexpected keyword argument 'block_table'

/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformer_engine/pytorch/attention.py:5073: TypeError

Details

The TransformerEngine v1.11 release changes the flash-attn version constraints in transformer_engine/pytorch/setup.py (https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/setup.py#L59) from what they were in TransformerEngine <= v1.10 (https://github.com/NVIDIA/TransformerEngine/blob/08a85d3b2657f1d4e0b478f6682c17fe6bba8b05/transformer_engine/pytorch/setup.py#L59).

Installing transformer-engine[pytorch]==1.11 with no other constraints pulls in flash-attn==2.6.3, which activates the code path introduced in commit https://github.com/NVIDIA/TransformerEngine/commit/27c6342ea8ad88034bf04b587dd13cb6088d2474, where the block_table kwarg is configured: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5014-L5015
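For context, the gating follows roughly the pattern sketched below; the version flag and the default value are approximations of the linked lines, not verbatim TransformerEngine code.

```python
# Rough sketch of the version-gated kwarg setup (see the linked attention.py
# lines for the real code). With flash-attn >= 2.5.7 installed, the optional
# block_table kwarg is always added to the forward kwargs.
from packaging.version import Version
import flash_attn

_flash_attn_2_5_7_plus = Version(flash_attn.__version__) >= Version("2.5.7")

fa_optional_forward_kwargs = {}
if _flash_attn_2_5_7_plus:
    # block_table exists for paged attention in flash_attn_varlen_func;
    # flash_attn_func has no such parameter.
    fa_optional_forward_kwargs["block_table"] = None
```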

In transformer_engine/pytorch/attention.py, the forward-pass function to call is selected here: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/transformer_engine/pytorch/attention.py#L5017-L5024

If func is set to flash_attn_varlen_func, which correctly consumes fa_optional_forward_kwargs["block_table"], the code (probably) works as intended. However, when func is set to flash_attn_func, which does not accept the block_table argument, the call fails with the following error:

E                   TypeError: flash_attn_func() got an unexpected keyword argument 'block_table'

/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformer_engine/pytorch/attention.py:5073: TypeError
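The mismatch can be reproduced in isolation with the toy sketch below. The two functions are simplified stand-ins for the real flash-attn entry points, and the guard at the end is only one way the kwarg could be gated, not the fix adopted upstream.

```python
# Toy stand-ins for the two flash-attn entry points (signatures heavily
# simplified; see the flash-attn source for the real ones).
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    return "dense path"  # no block_table parameter

def flash_attn_varlen_func(q, k, v, dropout_p=0.0, softmax_scale=None,
                           causal=False, block_table=None):
    return "varlen path"  # accepts block_table

fa_optional_forward_kwargs = {"block_table": None}

func = flash_attn_func  # the non-varlen branch chosen in attention.py
try:
    func(None, None, None, **fa_optional_forward_kwargs)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'block_table'

# One possible guard (illustration only): forward block_table only when the
# selected function is the varlen variant.
if func is not flash_attn_varlen_func:
    fa_optional_forward_kwargs.pop("block_table", None)
print(func(None, None, None, **fa_optional_forward_kwargs))  # "dense path"
```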

How to reproduce

The bug can be reproduced by running pytest TransformerEngine/tests/pytorch/test_sanity.py::test_sanity_gpt: https://github.com/NVIDIA/TransformerEngine/blob/c27ee60ec746210bcea4ec33958dbbff06706506/tests/pytorch/test_sanity.py#L571
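Equivalently, the same test can be invoked from Python (optional; the plain pytest command above is enough). The path assumes a checkout of the TransformerEngine repo in the current directory.

```python
# Run the failing sanity test programmatically and propagate the exit code.
import sys
import pytest

exit_code = pytest.main(
    ["TransformerEngine/tests/pytorch/test_sanity.py::test_sanity_gpt", "-x"]
)
sys.exit(exit_code)
```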

Workaround

Until the bug is fixed, the workaround appears to be to force the installation of "flash-attn<2.5.7" before installing transformer_engine. The bug has seemingly existed since v1.10, but it only became a point of failure once v1.11 allowed newer versions of flash-attn to be installed.
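A quick check along these lines can confirm the pinned environment is actually in effect before importing transformer_engine; this is an illustrative snippet, not part of either package.

```python
# Sanity check for the "flash-attn<2.5.7" workaround described above.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("flash-attn"))
if installed >= Version("2.5.7"):
    raise RuntimeError(
        f"flash-attn {installed} is installed; versions >= 2.5.7 trigger the "
        "block_table TypeError with transformer-engine[pytorch]==1.11"
    )
```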

cyanguwa commented 11 hours ago

@ksivaman could we cherry-pick this PR into the 1.11 release, please? Thanks! https://github.com/NVIDIA/TransformerEngine/pull/1222

timmoon10 commented 9 hours ago

Just updated release_v1.11.