openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

ElementalIrEmitterExecutionTest.IotaF8E4M3FN - Invalid LLVM IR #17323

Open · apivovarov opened this issue 4 days ago

apivovarov commented 4 days ago

I tried to add a new test, IotaF8E4M3FN, based on the existing IotaF8E4FNUZ test in xla/service/elemental_ir_emitter_test.cc.

The new test fails on CPU (x86_64):

XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E4M3FN) {
  constexpr char hlo_text[] = R"(
  HloModule IotaF8E4M3FN
  ENTRY main {
    ROOT iota_ = f8e4m3fn[4] iota(), iota_dimension=0
  }
  )";

  RunTest(hlo_text, {});
}

Error:

[ RUN      ] ElementalIrEmitterExecutionTest.IotaF8E4M3FN
2024-09-18 20:44:24.550357: W ./xla/service/compiler.h:213] Ignoring the buffer assignment proto provided.
2024-09-18 20:44:24.552712: E xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (xla/service/cpu/cpu_compiler.cc:954) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
UIToFP result must be FP or FP vector
  %5 = uitofp i64 %4 to i8

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR.
*** Begin stack trace ***
  tsl::CurrentStackTrace[abi:cxx11]()

  xla::status_macros::MakeErrorStream::Impl::GetStatus()

  xla::cpu::CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >)
  xla::cpu::CpuCompiler::RunBackend(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&)
  xla::HloRunner::CreateExecutableWithBufferAssignment(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, xla::BufferAssignmentProto const*, bool)
  xla::HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, xla::BufferAssignmentProto const*, std::vector<xla::ScopedShapedBuffer, std::allocator<xla::ScopedShapedBuffer> >, bool, xla::ExecutionProfile*)
  xla::HloRunner::Execute(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, absl::lts_20230802::Span<xla::Literal const* const>, bool, xla::ExecutionProfile*)
  xla::HloTestBase::RunAndCompareInternal(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, absl::lts_20230802::Span<xla::Literal* const>, std::optional<xla::ErrorSpec> const&, bool, std::function<void (xla::HloModule*)> const&, std::function<void (xla::HloModule*)> const&)
  xla::HloTestBase::RunAndCompareNoHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, absl::lts_20230802::Span<xla::Literal* const>, std::optional<xla::ErrorSpec> const&, std::function<void (xla::HloModule*)> const&, std::function<void (xla::HloModule*)> const&)

  void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*)
  testing::Test::Run()
  testing::TestInfo::Run()
  testing::TestSuite::Run()
  testing::internal::UnitTestImpl::RunAllTests()
  bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*)
  testing::UnitTest::Run()
  main
  __libc_start_main

*** End stack trace ***

xla/service/elemental_ir_emitter_test.cc:53: Failure
Value of: RunAndCompareNoHloPasses(std::move(module), args, nullopt)
  Actual: false (INTERNAL: RET_CHECK failure (xla/service/cpu/cpu_compiler.cc:954) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
UIToFP result must be FP or FP vector
  %5 = uitofp i64 %4 to i8

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR. )
Expected: true
[  FAILED  ] ElementalIrEmitterExecutionTest.IotaF8E4M3FN (17 ms)

A similar issue exists for the f8e5m2 and f8e4m3b11fnuz types.

@reedwm

apivovarov commented 3 days ago

The test uses HloTestBase::RunAndCompareNoHloPasses to execute hlo_text on both the CPU and the interpreter, comparing the results.

Since it uses the NoHloPasses variant of RunAndCompare, all HLO passes, including the float-normalization-<type> passes, are disabled. As a result, the f8e4m3fn[4] iota() operation cannot be executed directly on an x86_64 CPU, because this type is not natively supported.

Workaround: Run the test with the XLA flag --xla_enable_hlo_passes_only="float-normalization-<f8type>" to enable the float-normalization pass. This pass rewrites the iota operation to use the f16 data type and then inserts a convert operation to convert the result from f16 to f8.
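
For illustration, the module after float-normalization should look roughly like this (a sketch based on the rewrite described above; the instruction names are made up and the actual pass output may differ):

HloModule IotaF8E4M3FN
ENTRY main {
    iota_f16 = f16[4] iota(), iota_dimension=0
    ROOT iota_f8 = f8e4m3fn[4] convert(iota_f16)
}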

However, it is still unclear to me how the test works for f8e4m3fnuz and bf16, since these data types are also not natively supported on x86_64 CPUs.

Steps to Reproduce:

Command to run (executes HLO on the CPU with all HLO passes disabled):

./bazel-bin/xla/tools/run_hlo_module \
--input_format=hlo --platform=CPU a.hlo \
--print_literals --xla_disable_all_hlo_passes=true

HLO with three different iota operations (only the first two work, while the last one fails):

HloModule module
ENTRY main {
    ROOT iota_ = bf16[4] iota(), iota_dimension=0  // Works
    // ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 // Works
    // ROOT iota_ = f8e4m3fn[4] iota(), iota_dimension=0  // Failed
}

@reedwm, could you help clarify why the iota operation behaves differently across these types when HLO passes are disabled?

reedwm commented 3 days ago

It's expected that certain types, like FP8 types, don't work with arbitrary ops, which is why float-normalization upcasts them. My guess is that BF16 is supported as it's an LLVM IR type, according to this table. And F8E4M3FNUZ has some special logic here to handle it, although I'm not sure why we need to support this at all when float-normalization can handle it.

Anyway, I don't think there's anything we need to do for this issue, but let me know if you disagree. I don't think we need to support iota for FP8 types since float-normalization can handle it.

apivovarov commented 2 days ago

Hi Jake,

Could you please explain why we need special handling for F8E4M3FNUZ and F8E5M2FNUZ in the context of iota and conversion to F16? You can find the relevant code here: elemental_ir_emitter.cc#L3243-L3257

          if (component_element_type == F8E4M3FNUZ) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
          } else if (component_element_type == F8E5M2FNUZ) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }

          if (component_element_type == F8E4M3FNUZ ||
              component_element_type == F8E5M2FNUZ) {
            TF_ASSIGN_OR_RETURN(
                iota_result, EmitFloatingToF8fnuz(F16, float_val,
                                                  component_element_type, b_));
          } else {
            iota_result = float_val;
          }

It seems that the float-normalization-<type> pass should handle this in a unified manner for all types that the CPU does not support natively.
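
For example (again only a sketch of the expected rewrite, not the pass's actual output), the same normalization could turn

ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0

into

iota_f16 = f16[4] iota(), iota_dimension=0
ROOT iota_fnuz = f8e4m3fnuz[4] convert(iota_f16)

which, if the pass ran, would seem to make the F8E4M3FNUZ/F8E5M2FNUZ special case in the emitter unnecessary.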

Thank you!

@jakeh-gc