kuhar opened 9 months ago
This is just an umbrella issue to get started. Feel free to modify / fill in the blanks / link sub-issues and related discussions. cc: @antiagainst @MaheshRavishankar @qedawkins @raikonenfnu @hanhanW @bjacob
The gfx940 ISA supports two fp8 formats: fp8 and bf8. You can see both formats supported with mfma, including operands of mixed formats: https://llvm.org/docs/AMDGPU/AMDGPUAsmGFX940.html#vop3.
FP8 mfma is plumbed through the AMDGPU LLVM backend: https://reviews.llvm.org/D129906, for example:

```c
// CHECK-GFX940-LABEL: @test_mfma_f32_32x32x16_fp8_bf8
// CHECK-GFX940: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %a, i64 %b, <16 x float> %c, i32 0, i32 0, i32 0)
void test_mfma_f32_32x32x16_fp8_bf8(global v16f* out, long a, long b, v16f c)
{
  *out = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a, b, c, 0, 0, 0);
}
```
The fp8 operands are packed as i64. The only other amdgcn intrinsics for fp8 types are the cvt type conversions: https://github.com/llvm/llvm-project/blob/cd3942059eed7b7185f26bc583ac287a995db0d0/clang/include/clang/Basic/BuiltinsAMDGPU.def#L400-L407
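Since the builtin above takes its fp8 inputs as `long`, each i64 operand carries eight packed fp8 lanes. A minimal sketch of the packing (the little-endian lane order here is an assumption for illustration, not taken from the ISA docs):

```python
# Pack eight 8-bit fp8 bit patterns into a single i64 operand and unpack them
# again. Lane 0 sitting in the lowest byte is an assumed convention.
lanes = [0x38, 0x40, 0x44, 0x48, 0x4C, 0x50, 0x52, 0x54]  # arbitrary fp8 bit patterns
packed = int.from_bytes(bytes(lanes), "little")            # the i64 the intrinsic sees
unpacked = list(packed.to_bytes(8, "little"))
assert unpacked == lanes
```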
FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.
OCP 8-bit Floating Point Specification (OFP8)
Related paper with an overview of fp8 types: FP8 Formats for Deep Learning
Related blog post with overview of fp8 support for H100: https://lambdalabs.com/blog/nvidia-hopper-h100-and-fp8-support
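To make the E4M3/E5M2 trade-off concrete, here is a small sketch that decodes 8-bit patterns under the OFP8 layouts and compares the two formats' ranges. The decoder is an illustrative reimplementation written for this comment, not code from any of the linked projects:

```python
def decode_fp8(bits, exp_bits, mant_bits, bias):
    """Decode an 8-bit pattern under a generic sign/exponent/mantissa layout."""
    sign = -1.0 if bits >> 7 else 1.0
    exp = (bits >> mant_bits) & ((1 << exp_bits) - 1)
    mant = bits & ((1 << mant_bits) - 1)
    if exp == 0:  # subnormal
        return sign * (mant / (1 << mant_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + mant / (1 << mant_bits)) * 2.0 ** (exp - bias)

# Largest finite E5M2 value: 0b0_11110_11 (exponent all-ones is inf/NaN).
assert decode_fp8(0b0_11110_11, exp_bits=5, mant_bits=2, bias=15) == 57344.0
# Largest finite E4M3FN value: 0b0_1111_110 (only mantissa-all-ones is NaN).
assert decode_fp8(0b0_1111_110, exp_bits=4, mant_bits=3, bias=7) == 448.0
```

E5M2 trades one mantissa bit for a much larger dynamic range (57344 vs 448), which is why it suits gradients while E4M3 suits activations and weights.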
FP8 support in LLVM/MLIR:
RFC from Sep '22 by @stellaraccident, which added f8E5M2: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.
Since then, the other types plumbed all the way through MLIR are:
```cpp
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
```
```cpp
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
```
```mlir
func.func @float_attrs_pass() {
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2
    float_attr = 2. : f8E5M2
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
    float_attr = 2. : f8E4M3FN
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
    float_attr = 2. : f8E5M2FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
    float_attr = 2. : f8E4M3FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
    float_attr = 2. : f8E4M3B11FNUZ
  } : () -> ()
}
```
```cpp
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
    15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3FN = {
    8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
    7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
    4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
```
amdgcn's fp8 maps to f8E4M3FNUZ, while bf8 maps to f8E5M2FNUZ.
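The FNUZ mapping matters numerically: relative to E4M3FN, the FNUZ variant shifts the bias by one and reserves only the negative-zero pattern 0x80 for NaN, so its largest finite value is 240 rather than 448. A small decoder written for this comment, assuming the APFloat semantics quoted above:

```python
def decode_e4m3fnuz(bits):
    """Decode f8E4M3FNUZ: bias 8, no infinities, NaN only at 0x80 (no -0)."""
    if bits == 0x80:
        return float("nan")
    sign = -1.0 if bits >> 7 else 1.0
    exp = (bits >> 3) & 0xF
    mant = bits & 0x7
    if exp == 0:  # subnormal
        return sign * (mant / 8.0) * 2.0 ** (1 - 8)
    return sign * (1 + mant / 8.0) * 2.0 ** (exp - 8)

assert decode_e4m3fnuz(0x7F) == 240.0  # largest finite value: 0b0_1111_111
assert decode_e4m3fnuz(0x40) == 1.0    # 0b0_1000_000, exponent field == bias
```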
> FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?
Also, https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html talks a bit about fp8 on NVIDIA GPUs, which is a useful reference.
In general, fp8 is currently used in a very ad-hoc way: ISAs just provide conversions and tensor/matrix core ops. For training we also have different fp8 scaling factors for different tensors and need model/framework-level handling there, so that side is quite ad-hoc too.
So, as we discussed in the meeting, getting a minimal matmul to exercise fp8 + tensor/matrix cores in IREE/SHARK would be a good start and a foundation for everything else. We can then build the other parts on top.
> FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.
>
> If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

This is explained a bit in the NVIDIA doc linked in my previous comment:

> During training neural networks both of these types may be utilized. Typically forward activations and weights require more precision, so E4M3 datatype is best used during forward pass. In the backward pass, however, gradients flowing through the network typically are less susceptible to the loss of precision, but require higher dynamic range. Therefore they are best stored using E5M2 data format. H100 TensorCores provide support for any combination of these types as the inputs, enabling us to store each tensor using its preferred precision.
Support in MLIR/LLVM/AMDGPU already seems quite promising, so as discussed this morning the plan is to show a very simple example using fp8 in IREE first, something like
```mlir
module {
  func.func @matmul_static(%arg0: tensor<32x32xi8>, %arg1: tensor<32x32xi8>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %1 = tensor.bitcast %arg1 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %2 = linalg.matmul ins(%0, %1 : tensor<32x32xf8E4M3FNUZ>, tensor<32x32xf8E4M3FNUZ>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %2 : tensor<32x32xf32>
  }
}
```
or, to avoid the need to also handle mfma at the same time, just something as simple as
```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @extend_i8(%arg0: tensor<32xi8>) -> tensor<32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32xi8> to tensor<32xf8E4M3FNUZ>
    %1 = tensor.empty() : tensor<32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<32xf8E4M3FNUZ>) outs(%1 : tensor<32xf32>) {
    ^bb0(%in: f8E4M3FNUZ, %out: f32):
      %3 = arith.extf %in : f8E4M3FNUZ to f32
      linalg.yield %3 : f32
    } -> tensor<32xf32>
    return %2 : tensor<32xf32>
  }
}
```
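For the truncation direction (f32 back to f8E4M3FNUZ), a brute-force reference can pick the nearest of the 255 non-NaN encodings. This is only a numerical sketch for checking lowered code against, not how any backend implements it; rounding-mode and overflow details are deliberately glossed over:

```python
def fp8_e4m3fnuz_value(bits):
    # Value of an f8E4M3FNUZ bit pattern (bias 8, no infinities, NaN only at 0x80).
    if bits == 0x80:
        return float("nan")
    sign = -1.0 if bits >> 7 else 1.0
    exp = (bits >> 3) & 0xF
    mant = bits & 0x7
    scale = 2.0 ** (1 - 8) if exp == 0 else 2.0 ** (exp - 8)
    frac = mant / 8.0 if exp == 0 else 1 + mant / 8.0
    return sign * frac * scale

def trunc_to_e4m3fnuz(x):
    # Nearest representable value by exhaustive search over non-NaN patterns;
    # tie-breaking is whatever min() picks first, not true round-to-nearest-even.
    return min((b for b in range(256) if b != 0x80),
               key=lambda b: abs(fp8_e4m3fnuz_value(b) - x))

# Round-tripping a representable value is exact:
assert fp8_e4m3fnuz_value(trunc_to_e4m3fnuz(1.0)) == 1.0
# 1.05 falls between 1.0 and 1.125, the next representable step:
assert fp8_e4m3fnuz_value(trunc_to_e4m3fnuz(1.05)) == 1.0
```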
Explanation of the LLVM fp8 semantics naming convention: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero (no negative zero).
Source: https://github.com/jax-ml/ml_dtypes?tab=readme-ov-file#float8_e5m2fnuz
Looking through support in MLIR and the lowering into NVVM/ROCDL, it seems to be already there as well:
MFMA to ROCDL intrinsics
Tensor core instructions lowering
So for the examples in this comment https://github.com/nod-ai/SHARK/issues/2054#issuecomment-1877842991, the extension/truncation should just pass through and compile on AMD. For the mfma support, it would be great if we could just take a single matmul of the exact mfma shape and have it lower to that operation. Literally all tile sizes would be 1... it should vectorize to vector.contract, lower to amdgpu.mfma, and then to the rocdl intrinsics.
This is an umbrella issue for allowing fp8 type(s) in SHARK, spanning all the required layers of the stack: Turbine, IREE, MLIR, and LLVM, including backends of interest like ROCm.
Some initial research is required to scope this properly and divide it into subtasks, but the main work items are roughly:
llvm::APFloat