Description

This PR relocates the FP8 metadata for attention from `FusedAttention` to `DotProductAttention`. It makes `DotProductAttention` a `TransformerEngineBaseModule` and `FusedAttention` a `torch.nn.Module`. In the future, `core_attention._extra_state` will be the centralized place for FP8 metadata for any attention backend, instead of `core_attention.fused_attention._extra_state`, which was specific to `FusedAttention` (originated in #768).
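As a concrete illustration of what moves, here is where the serialized FP8 metadata appears in a model's `state_dict` before and after this PR. This is a minimal sketch: the layer prefix is hypothetical, and only the `core_attention...` suffixes come from this PR.

```python
# Hypothetical layer prefix for illustration; only the suffixes below
# ("core_attention[.fused_attention]._extra_state") reflect this PR.
prefix = "transformer.layers.0.self_attention."

# Before: FP8 metadata was serialized by FusedAttention alone.
old_key = prefix + "core_attention.fused_attention._extra_state"

# After: FP8 metadata is serialized by DotProductAttention, so it is
# available to any attention backend, not just FusedAttention.
new_key = prefix + "core_attention._extra_state"
```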
Type of change
[ ] Documentation change (change only to the documentation, either a fix or a new content)
[ ] Bug fix (non-breaking change which fixes an issue)
[x] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
Please list the changes introduced in this PR:
- Made `DotProductAttention` a subclass of `TransformerEngineBaseModule`
- Reverted `FusedAttention` to a plain `torch.nn.Module`
- Passed `fp8_meta` from `DotProductAttention` down to `FusedAttention` (see the sketch after this list)
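A minimal structural sketch of the three changes above, assuming the `TransformerEngineBaseModule` import path from the transformer_engine source tree. The constructor and forward signatures are simplified placeholders, not the library's actual API, and the real base class imposes requirements (abstract methods, FP8 bookkeeping hooks) that are omitted here.

```python
import torch

from transformer_engine.pytorch.module.base import TransformerEngineBaseModule


class FusedAttention(torch.nn.Module):
    """Back to a plain nn.Module: it no longer owns FP8 metadata and
    instead receives it from the caller on each forward pass."""

    def forward(self, q, k, v, fp8_meta=None):
        # Placeholder for the fused attention kernel launch; fp8_meta is
        # consulted only when FP8 execution is active.
        raise NotImplementedError


class DotProductAttention(TransformerEngineBaseModule):
    """Now a TransformerEngineBaseModule, so it owns self.fp8_meta and
    serializes it via core_attention._extra_state for every backend."""

    def __init__(self):
        super().__init__()
        self.fused_attention = FusedAttention()

    def forward(self, q, k, v):
        # Hand the centrally managed FP8 metadata down to the backend.
        return self.fused_attention(q, k, v, fp8_meta=self.fp8_meta)
```

Keeping the FP8 bookkeeping at the `DotProductAttention` level means future backends can read the same `fp8_meta` rather than each serializing their own copy, which is the point of centralizing `_extra_state`.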
Checklist: