NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

feat(pytorch): Allow TransformerLayer and MultiheadAttention to accept sequence length parameters #1066

Closed: hXl3s closed this pull request 1 month ago

hXl3s commented 2 months ago

Description

TransformerLayer and MultiheadAttention do not allow passing arbitrary-length sequences. While this feature is supported by DotProductAttention, it cannot be controlled from the higher abstraction layers. This PR fixes that issue.

Additionally, this PR fixes a runtime bug when the THD layout is used.
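
For illustration, a minimal sketch of the intended usage, assuming the new forward kwargs mirror the ones DotProductAttention already exposes (cu_seqlens_q/cu_seqlens_kv and max_seqlen_q/max_seqlen_kv; the exact names and THD input shape in this PR may differ):

```python
import torch
import transformer_engine.pytorch as te

# Three packed sequences of lengths 5, 3 and 8 tokens; THD hidden_states are
# assumed here to be (total_tokens, hidden_size) with cumulative sequence lengths.
cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32, device="cuda")
total_tokens, hidden_size = 16, 1024

layer = te.TransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=4 * hidden_size,
    num_attention_heads=16,
    attn_input_format="thd",
    params_dtype=torch.bfloat16,
).cuda()

x = torch.randn(total_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# Hypothetical kwargs: the PR forwards the sequence-length parameters down to
# DotProductAttention, which already accepts cu_seqlens_q/cu_seqlens_kv.
out = layer(
    x,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=8,
    max_seqlen_kv=8,
)
```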

Fixes # (issue)

Type of change

Changes

Please list the changes introduced in this PR:

Checklist:

ptrendx commented 2 months ago

Hi @hXl3s, please sign your commits. See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work for details.

ptrendx commented 1 month ago

Could you also add an error when somebody tries to run THD in inference with KV-cache here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L6732-L6735? @sudhakarsingh27 FYI, since this will touch the pieces you are looking at for merging THD inference support.
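
For reference, the requested guard could look roughly like this (a sketch only; `inference_params` and `qkv_format` are the names used elsewhere in attention.py, and the exact placement and wording may differ):

```python
# Sketch of the requested guard inside DotProductAttention.forward(): reject the
# THD layout when a KV-cache (inference_params) is in use, since that path is
# not supported yet.
if inference_params is not None:
    assert qkv_format != "thd", (
        "qkv_format='thd' is not yet supported together with KV-cache "
        "(inference_params) during inference."
    )
```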

ptrendx commented 1 month ago

Also @hXl3s could you add some unit test to make sure it works (e.g. comparing thd vs bshd outputs of TransformerLayer)?

hXl3s commented 1 month ago

@ptrendx Added a test case comparing the output of THD vs. BSHD for float16 and bfloat16. float32 is skipped because cuDNN (or whichever backend is used) apparently does not support THD with float32.

Additionally, as requested, there is now an assert that prevents using the THD layout with KV-cache during inference.
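
In outline, the comparison boils down to something like the following (a simplified sketch, not the actual test code; the TransformerLayer sequence-length kwargs and tolerances here are illustrative):

```python
import pytest
import torch
import transformer_engine.pytorch as te

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])  # float32: THD unsupported
def test_transformer_layer_thd_vs_bshd(dtype):
    torch.manual_seed(0)
    batch, seqlen, hidden, heads = 2, 32, 256, 8

    layer_bshd = te.TransformerLayer(hidden, 4 * hidden, heads,
                                     attn_input_format="bshd", params_dtype=dtype).cuda()
    layer_thd = te.TransformerLayer(hidden, 4 * hidden, heads,
                                    attn_input_format="thd", params_dtype=dtype).cuda()
    layer_thd.load_state_dict(layer_bshd.state_dict())

    x_bshd = torch.randn(batch, seqlen, hidden, dtype=dtype, device="cuda")
    out_bshd = layer_bshd(x_bshd)

    # All sequences are full length, so packing to THD is just a reshape plus
    # cumulative sequence lengths.
    x_thd = x_bshd.reshape(batch * seqlen, hidden)
    cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen,
                              dtype=torch.int32, device="cuda")
    out_thd = layer_thd(x_thd, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
                        max_seqlen_q=seqlen, max_seqlen_kv=seqlen)

    torch.testing.assert_close(out_thd, out_bshd.reshape(batch * seqlen, hidden),
                               rtol=1e-2, atol=1e-2)
```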

hXl3s commented 1 month ago

Resolved all comments

ptrendx commented 1 month ago

/te-ci pytorch

ptrendx commented 1 month ago

@hXl3s Could you skip the THD test when the GPU SM arch is lower than 8.0 (neither FlashAttention nor cuDNN supports THD there)?
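
Something along these lines should work (a sketch; the test file may already have its own capability helper to reuse):

```python
import pytest
import torch

# THD needs a fused-attention backend: FlashAttention and cuDNN fused attention
# both require SM 8.0+ (Ampere or newer), so skip the test elsewhere.
sm = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
requires_thd_support = pytest.mark.skipif(
    sm[0] < 8, reason="THD layout requires SM 8.0+ (FlashAttention / cuDNN)"
)

@requires_thd_support
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_transformer_layer_thd_vs_bshd(dtype):
    ...
```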

ptrendx commented 1 month ago

/te-ci pytorch