NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Feature Request] Use fp8 tensor cores for computation in attention implementation #266

Open DD-DuDa opened 1 year ago

DD-DuDa commented 1 year ago

I noticed that the attention implementation uses torch.baddbmm for Q @ K^T and torch.bmm for softmax(Q @ K^T) @ V. Does this mean FP8 tensor cores are not used for these two matmuls, and that they are computed in FP32 instead? Could you provide an fp8_batch_gemm operation for this purpose?

MoFHeka commented 1 year ago

Same request