Hi @hXl3s please sign your commits. See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work for details.
Could you also add an error when somebody tries to run THD in inference with KV-cache here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L6732-L6735? @sudhakarsingh27 FYI, since this will touch the pieces you are looking at for merging THD inference support.
Also @hXl3s could you add a unit test to make sure it works (e.g., comparing THD vs. BSHD outputs of TransformerLayer)?
@ptrendx Added a test case comparing the output of THD vs. BSHD for float16 and bfloat16. float32 is skipped, as apparently cuDNN (or whichever backend is used) does not support THD with float32.
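For reference, a minimal sketch of what such a comparison test could look like; the `attn_input_format="thd"` value and the exact forward arguments are assumptions based on this PR's change list, not necessarily the final API:

```python
import pytest
import torch
import transformer_engine.pytorch as te

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_transformer_layer_thd_vs_bshd(dtype):
    batch, seqlen, hidden, heads = 2, 128, 256, 4

    # Two layers with identical weights, differing only in input layout.
    # attn_input_format="thd" is an assumption about the post-PR API.
    layer_bshd = te.TransformerLayer(
        hidden, 4 * hidden, heads,
        attn_input_format="bshd", self_attn_mask_type="padding",
        params_dtype=dtype,
    ).cuda()
    layer_thd = te.TransformerLayer(
        hidden, 4 * hidden, heads,
        attn_input_format="thd", self_attn_mask_type="padding",
        params_dtype=dtype,
    ).cuda()
    layer_thd.load_state_dict(layer_bshd.state_dict())

    x = torch.randn(batch, seqlen, hidden, dtype=dtype, device="cuda")
    out_bshd = layer_bshd(x)

    # Same tokens packed as [total_tokens, hidden]; every sequence is
    # full length here, so cu_seqlens is simply [0, 128, 256].
    cu_seqlens = torch.arange(
        0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda"
    )
    out_thd = layer_thd(
        x.reshape(-1, hidden),
        cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=seqlen, max_seqlen_kv=seqlen,
    )

    torch.testing.assert_close(
        out_thd, out_bshd.reshape(-1, hidden), atol=2e-2, rtol=2e-2
    )
```

The loose tolerances account for the two layouts potentially dispatching to different attention backends.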
Additionally, as requested, there is now an assert that checks that the THD layout is not used with the KV-cache during inference.
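Roughly, the guard in attention.py has this shape (the variable names here are illustrative, not the exact ones in the diff):

```python
# Illustrative only: reject the THD layout when a KV-cache is active,
# since THD inference with KV-cache is not supported yet.
if inference_params is not None:
    assert qkv_format != "thd", (
        "qkv_format='thd' is not supported with KV-cache during inference"
    )
```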
Resolved all comments
/te-ci pytorch
@hXl3s Could you skip the THD test when the GPU SM arch is lower than 8.0 (as neither FlashAttention nor cuDNN supports it there)?
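One possible shape for that skip, assuming the tests use pytest:

```python
import pytest
import torch

# Skip on pre-Ampere GPUs: neither FlashAttention nor cuDNN fused
# attention supports the THD layout below SM 8.0.
if torch.cuda.get_device_capability() < (8, 0):
    pytest.skip("THD layout requires SM 8.0 or newer")
```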
/te-ci pytorch
Description
TransformerLayer and MultiheadAttention do not allow passing arbitrary-length sequences. While this feature is supported by DotProductAttention, it cannot be controlled from the higher abstraction layers. This PR fixes that issue.
Additionally, it fixes a runtime bug that occurs when the THD layout is used.
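To illustrate the intended usage (argument names are taken from the change list below; the packing itself is a sketch):

```python
import torch

# Three ragged sequences packed along the token dimension.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
# cu_seqlens == tensor([0, 5, 8, 15], dtype=torch.int32)
max_seqlen = int(seq_lens.max())  # 7

# These are passed to MultiheadAttention / TransformerLayer as
# cu_seqlens_q / cu_seqlens_kv and max_seqlen_q / max_seqlen_kv.
```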
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- Added `cu_seqlens_q`, `cu_seqlens_kv`, `max_seqlen_q`, and `max_seqlen_kv` arguments to MultiheadAttention and TransformerLayer

Checklist: