apivovarov opened 2 months ago
The test uses `HloTestBase::RunAndCompareNoHloPasses` to execute `hlo_text` on both the CPU and the interpreter and compare the results. Since it uses the `NoHloPasses` variant of `RunAndCompare`, all HLO passes, including the `float-normalization-<type>` passes, are disabled. As a result, the CPU backend cannot directly execute the `f8e4m3fn[4] iota()` operation on an x86_64 CPU, because this type is not natively supported.
Workaround: run the test with the XLA flag `--xla_enable_hlo_passes_only="float-normalization-<f8type>"` to enable the float-normalization pass. This pass rewrites the iota operation to use the f16 data type and then inserts a convert operation to cast the result from f16 back to f8.
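For illustration, the normalized module for the failing case should look roughly like this (a hand-written sketch of the expected pass output, not verbatim output; instruction names are made up):

```
HloModule module

ENTRY main {
  iota.f16 = f16[4] iota(), iota_dimension=0
  ROOT convert.f8 = f8e4m3fn[4] convert(iota.f16)
}
```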
However, I am still unclear on how the test works for `f8e4m3fnuz` and `bf16`, as these data types are also not natively supported on x86_64 CPUs.
Steps to reproduce:

Command (executes the HLO on the CPU with all HLO passes disabled):

```shell
./bazel-bin/xla/tools/run_hlo_module \
  --input_format=hlo --platform=CPU a.hlo \
  --print_literals --xla_disable_all_hlo_passes=true
```
HLO with three different iota operations (only the first two work, while the last one fails):

```
HloModule module

ENTRY main {
  ROOT iota_ = bf16[4] iota(), iota_dimension=0          // Works
  // ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 // Works
  // ROOT iota_ = f8e4m3fn[4] iota(), iota_dimension=0   // Failed
}
```
@reedwm , could you help clarify why there is a discrepancy between different types and the iota operation when HLO passes are disabled?
It's expected that certain types, like FP8 types, don't work with arbitrary ops, which is why float-normalization upcasts them. My guess is that BF16 is supported as it's an LLVM IR type, according to this table. And F8E4M3FNUZ has some special logic here to handle it, although I'm not sure why we need to support this at all when float-normalization can handle it.
Anyway, I don't think there's anything we need to do for this issue, but let me know if you disagree. I don't think we need to support iota for FP8 types since float-normalization can handle it.
Hi Jake,

Could you please explain why we need special handling for `F8E4M3FNUZ` and `F8E5M2FNUZ` in the context of iota and conversion to F16? The relevant code is here: `elemental_ir_emitter.cc#L3243-L3257`
```cpp
if (component_element_type == F8E4M3FNUZ) {
  float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
} else if (component_element_type == F8E5M2FNUZ) {
  float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
} else {
  float_ir_type =
      llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
}
if (component_element_type == F8E4M3FNUZ ||
    component_element_type == F8E5M2FNUZ) {
  TF_ASSIGN_OR_RETURN(
      iota_result, EmitFloatingToF8fnuz(F16, float_val,
                                        component_element_type, b_));
} else {
  iota_result = float_val;
}
```
It seems that the `float-normalization-<type>` pass should handle this in a unified manner for all types without native CPU support.

Thank you!
@jakeh-gc
I tried to add a new test, `IotaF8E4M3FN`, based on the `IotaF8E4FNUZ` code (`xla/service/elemental_ir_emitter_test.cc`). The new test failed on CPU (x86_64).

Error:
A similar issue exists for the `f8e5m2` and `f8e4m3b11fnuz` types. @reedwm