NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[C/PyTorch] Add support for bottom-right-diagonal causal mask #960

Open · cyanguwa opened this pull request 1 week ago

cyanguwa commented 1 week ago

Description

Currently, our causal masks are aligned to the top-left corner of the softmax matrix, but in inference/KV caching, users often need to align them to the bottom-right corner instead. This PR adds two mask types, `causal_bottom_right` and `padding_causal_bottom_right`, to support the new alignment. The old mask types, `causal` and `padding_causal`, continue to denote the top-left alignment.
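To illustrate the difference, here is a minimal sketch (plain PyTorch, not TE's internal implementation; the helper `causal_mask` is hypothetical) of the two alignments for a rectangular softmax matrix, where `True` marks a masked-out position:

```python
import torch

def causal_mask(q_len: int, kv_len: int, bottom_right: bool = False) -> torch.Tensor:
    """Boolean mask of shape (q_len, kv_len); True = attention disallowed."""
    q_idx = torch.arange(q_len).unsqueeze(1)    # query positions, column vector
    kv_idx = torch.arange(kv_len).unsqueeze(0)  # key/value positions, row vector
    if bottom_right:
        # Diagonal anchored at the bottom-right corner: the last query row
        # attends to all keys. This matches KV caching, where the queries
        # are the most recent tokens of the sequence.
        return kv_idx > q_idx + (kv_len - q_len)
    # Diagonal anchored at the top-left corner: query i attends to keys 0..i.
    return kv_idx > q_idx
```

For example, with `q_len=2` and `kv_len=4`, the top-left mask lets query 0 see only key 0, while the bottom-right mask lets query 0 see keys 0..2 and query 1 see all four keys. When `q_len == kv_len`, the two alignments coincide.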

The new support matrix is:

| attn_mask_type | supported backends |
| --- | --- |
| `no_mask` | All |
| `padding` | FlashAttention, FusedAttention |
| `causal` (self-attention) | All |
| `causal` (cross-attention) | FusedAttention |
| `padding_causal` (self-attention) | FlashAttention, FusedAttention |
| `padding_causal` (cross-attention) | FusedAttention |
| `causal_bottom_right` | All |
| `padding_causal_bottom_right` | FlashAttention, FusedAttention |
| `arbitrary` | UnfusedDotProductAttention |
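As a usage sketch, the new mask types are selected via `attn_mask_type` just like the existing ones. The constructor and forward arguments below follow the TE PyTorch API as I understand it at the time of this PR and should be treated as illustrative:

```python
import torch
import transformer_engine.pytorch as te

# Hedged sketch: check your installed TE version's DotProductAttention
# signature before relying on these argument names.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal_bottom_right",  # one of the mask types added by this PR
)

# Default qkv_format is "sbhd": (sequence, batch, heads, head_dim).
q = torch.randn(2, 1, 16, 64, dtype=torch.bfloat16, device="cuda")   # 2 new tokens
kv = torch.randn(32, 1, 16, 64, dtype=torch.bfloat16, device="cuda") # 32 cached tokens
# With bottom-right alignment, the last query row attends to all 32 keys.
out = attn(q, kv, kv)
```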

This PR also extracts and streamlines the utility function `get_attention_backend()` for backend availability checks. Users can call it with their model parameters and runtime environment to see which backends can support a particular set of inputs, and which backend TransformerEngine's internal logic will select.
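A hypothetical query might look like the following. `get_attention_backend()` is named in this PR, but the parameter container, its field names, and the return shape shown here are assumptions, not a confirmed signature:

```python
import torch
from transformer_engine.pytorch.attention import AttentionParams, get_attention_backend

# Illustrative fields only; consult the TE source for the full AttentionParams set.
params = AttentionParams(
    qkv_dtype=torch.bfloat16,
    max_seqlen_q=128,
    max_seqlen_kv=4096,
    head_dim=64,
    attn_mask_type="causal_bottom_right",
)
# Reports which backends can serve these inputs and which one TE would pick
# (the exact return structure may vary by TE version).
backends = get_attention_backend(params)
```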

To facilitate the addition of bottom-right causal masks, this PR also makes two other changes for decoders (see PR #895).


cyanguwa commented 1 day ago

/te-ci pytorch

cyanguwa commented 1 day ago

/te-ci pytorch

cyanguwa commented 15 hours ago

/te-ci pytorch

cyanguwa commented 9 hours ago

/te-ci pytorch