NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Check and set sliding window size based on attention mask type #895

Closed. cyanguwa closed this pull request 3 months ago.

cyanguwa commented 3 months ago

Description

This PR improves the check_set_window_size function in PyTorch so that window_size is always set appropriately for the attention mask type (a sketch of this logic follows the list below):

no_mask, padding      : window_size = (-1, -1) or (>=0, >=0)
causal, padding_causal: window_size = (-1,  0)
arbitrary             : window_size = (-1, -1)
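
For illustration, the following is a minimal sketch of the normalization described above. The function name mirrors check_set_window_size, but the signature, mask-type strings, and error handling here are assumptions for this example and may differ from the actual implementation in transformer_engine/pytorch.

```python
from typing import Optional, Tuple


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Hypothetical sketch: normalize window_size for a given mask type.

    (-1, -1) means no sliding window; (-1, 0) means look-back only (causal).
    """
    if attn_mask_type in ("no_mask", "padding"):
        # Either no sliding window or an explicit non-negative window.
        if window_size is None:
            return (-1, -1)
        assert window_size == (-1, -1) or (
            window_size[0] >= 0 and window_size[1] >= 0
        ), f"window_size {window_size} is invalid for mask type {attn_mask_type}"
        return window_size
    if attn_mask_type in ("causal", "padding_causal"):
        # Causal masks only attend to the past: the right window must be 0.
        if window_size is None:
            return (-1, 0)
        return (window_size[0], 0)
    if attn_mask_type == "arbitrary":
        # Arbitrary masks do not use a sliding window.
        return (-1, -1)
    raise ValueError(f"Unknown attn_mask_type: {attn_mask_type}")
```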


cyanguwa commented 3 months ago

/te-ci pytorch
