A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
I noticed that in the attention implementation, torch.baddbmm and torch.bmm are used for Q @ K and QK @ V respectively. Does this mean FP8 tensor cores are not used for these computations, and they are instead performed in fp32? Could you provide an fp8_batch_gemm operation for this purpose?
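For context, a minimal sketch of the pattern the question describes, using the stock PyTorch batched-matmul calls (torch.baddbmm and torch.bmm are real PyTorch APIs; the shapes and scaling here are illustrative assumptions, not the library's exact code). These ops execute in the tensors' own dtype, so with fp32 inputs no FP8 tensor cores are involved:

```python
import torch

# Assumed illustrative shapes: (batch * heads, seq_len, head_dim)
b, s, d = 4, 16, 8
q = torch.randn(b, s, d)
k = torch.randn(b, s, d)
v = torch.randn(b, s, d)

# Q @ K^T as a batched GEMM; the 1/sqrt(d) scale is folded in via alpha.
# With beta=0, the first argument is ignored, so an uninitialized
# buffer is safe here.
scores = torch.baddbmm(
    torch.empty(b, s, s),
    q,
    k.transpose(1, 2),
    beta=0.0,
    alpha=1.0 / d**0.5,
)
probs = torch.softmax(scores, dim=-1)

# (softmax(QK^T)) @ V as a second batched GEMM, again in the input dtype.
out = torch.bmm(probs, v)
print(out.shape)  # torch.Size([4, 16, 8])
```

Both matmuls inherit the precision of q, k, and v (fp32 above), which is what motivates the request for an FP8-aware batched GEMM.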