NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

How to cast 16/32-bit to FP8? #965

Open mxjmtxrm opened 1 week ago

mxjmtxrm commented 1 week ago

Hi, how do I cast a float32/bfloat16 tensor to FP8? I want to do W8A8 (FP8) quantization, but I couldn't find an example of quantizing activations to FP8.

timmoon10 commented 1 week ago

The easiest approach is to use native PyTorch FP8 dtypes:

import torch

x = torch.randn(128, device="cuda", dtype=torch.float32)
y = x.to(dtype=torch.float8_e4m3fn)  # or torch.float8_e5m2

You could also use transformer_engine.pytorch.Float8Tensor / float8_experimental.Float8Tensor:

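import transformer_engine.pytorch as te
import float8_experimental

scale = torch.ones(1, device="cuda", dtype=torch.float32)  # per-tensor scale (1.0 = unscaled)
y1 = te.Float8Tensor.to_float8(x)
y2 = float8_experimental.Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)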

These two classes share much of their design and provide some convenient features: support for scaling factors, casting to higher precision for ops that don't support FP8, and, in the case of float8_experimental, torch.compile support.

Finally, you could directly use the FP8 kernels from Transformer Engine:

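import transformer_engine.pytorch as te
import transformer_engine_torch

# fp8_meta is TE's internal FP8 scaling metadata, normally created and
# updated by TE modules; it is not constructed in this snippet.
y = te.cpp_extensions.cast_to_fp8(
    x,
    fp8_meta,
    0,  # index of the scaling factor to use within fp8_meta
    transformer_engine_torch.DType.kFloat8E4M3,
)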

I strongly advise against using these internal functions though. Their APIs are unstable, messy, and tightly integrated with TE's logic for computing FP8 scaling factors.

mxjmtxrm commented 1 week ago

Thanks @timmoon10. How do I do mixed-precision calculations, e.g. multiply an FP8 tensor by an FP16 tensor to get an FP16 output?

timmoon10 commented 6 days ago

If you just want the performance benefit of FP8 matmuls, I recommend using Transformer Engine modules (like te.Linear) in your model (see this FP8 tutorial). They will internally handle the FP8 casts and FP8 scaling factors.

If you want more control, you'll have to get a bit into the weeds. I'm not sure whether native PyTorch FP8 tensors support matmuls, and even if they did, you would run into numerical issues without FP8 scaling factors. float8_experimental.Float8Tensor, on the other hand, does support matmuls with scaling factors (see addmm_float8_unwrapped). As far as I can tell, this just ends up calling cuBLAS (see scaled_gemm).

Be advised that cuBLAS FP8 matmuls require both inputs to be FP8 (see the FP8 support matrix for cublasLtMatmul), so an FP8-by-FP16 matmul isn't supported directly. Implementing a custom matmul kernel that accepts mixed FP8 and FP16 inputs may be possible with CUTLASS, but it would get quite involved (and would probably still be slower than TE for end-to-end training).