mxjmtxrm opened this issue 1 week ago
The easiest approach is to use native PyTorch FP8 dtypes:
```python
import torch

x = torch.randn(128, device="cuda", dtype=torch.float32)
y = x.to(dtype=torch.float8_e4m3fn)  # or torch.float8_e5m2
```
You could also use `transformer_engine.pytorch.Float8Tensor` or `float8_experimental.Float8Tensor`:

```python
import transformer_engine.pytorch as te
import float8_experimental

scale = torch.ones(1, device="cuda", dtype=torch.float32)
y1 = te.Float8Tensor.to_float8(x)
y2 = float8_experimental.Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)
```
These classes are closely related and have some nice convenience features: support for scaling factors, casting to higher precision for ops that don't support FP8, and `float8_experimental` has `torch.compile` support.
Finally, you could directly use the FP8 kernels from Transformer Engine:
```python
y = te.cpp_extensions.cast_to_fp8(
    x,
    fp8_meta,
    0,
    transformer_engine_torch.DType.kFloat8E4M3,
)
```
I strongly advise against using these internal functions though. Their APIs are unstable, messy, and tightly integrated with TE's logic for computing FP8 scaling factors.
Thanks @timmoon10. How do I do mixed-precision calculations, e.g. a matrix multiplication of FP8 and FP16 tensors that produces FP16 output?
If you just want the performance benefit of FP8 matmuls, I recommend using Transformer Engine modules (like `te.Linear`) in your model (see this FP8 tutorial). They will internally handle the FP8 casts and FP8 scaling factors.
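For reference, a minimal sketch of that pattern, loosely following the FP8 tutorial (the recipe arguments and layer sizes here are just illustrative):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Weights stay in higher precision; the FP8 casts and scaling-factor
# bookkeeping happen inside the fp8_autocast region.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

model = te.Linear(768, 768, bias=True, params_dtype=torch.float16).cuda()
inp = torch.randn(32, 768, device="cuda", dtype=torch.float16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)  # FP8 GEMM internally, output comes back in FP16
```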
If you want more control, you'll have to get a bit into the weeds. I'm not sure if native PyTorch FP8 tensors support matmuls (even if they did, there would be numerical issues without FP8 scaling factors), but I see that `float8_experimental.Float8Tensor` does support matmuls with scaling factors (see `addmm_float8_unwrapped`). As far as I can tell, this just ends up calling cuBLAS (see `scaled_gemm`). Be advised that cuBLAS only supports FP8 GEMMs where both inputs are FP8 (see the FP8 support matrix for `cublasLtMatmul`). Implementing a custom matmul kernel with support for mixed FP8 and FP16 inputs may be possible using CUTLASS, but it would get quite involved (and would probably still be slower than TE for end-to-end training).
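Putting those pieces together, one workaround is to cast the FP16 operand to FP8 as well and let cuBLAS run an FP8 x FP8 GEMM with an FP16 output via `torch._scaled_mm`. Treat this as a rough sketch under assumptions: `_scaled_mm` is a private PyTorch op whose signature has changed between releases (this assumes a recent, roughly 2.4+, build), and the per-tensor scaling below is the simplest possible recipe, not what TE or `float8_experimental` actually do.

```python
import torch

def to_fp8(t: torch.Tensor):
    # Per-tensor scaling: map the max magnitude near the E4M3 max (~448).
    amax = t.abs().max().float().clamp(min=1e-12)
    scale = torch.tensor(448.0, device=t.device) / amax
    return (t.float() * scale).to(torch.float8_e4m3fn), scale

a_fp16 = torch.randn(128, 256, device="cuda", dtype=torch.float16)
b_fp16 = torch.randn(256, 64, device="cuda", dtype=torch.float16)

a_fp8, a_scale = to_fp8(a_fp16)
b_fp8, b_scale = to_fp8(b_fp16)

# _scaled_mm wants the second operand in column-major layout, and the
# scale_a/scale_b arguments are the dequantization scales applied to the
# accumulated result.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t().contiguous().t(),
    scale_a=1.0 / a_scale,
    scale_b=1.0 / b_scale,
    out_dtype=torch.float16,
)
```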
Hi, how do I cast a float/bfloat16 tensor to FP8? I want to do W8A8 (FP8) quantization, but I couldn't find an example of quantizing activations to the FP8 format.