openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

//xla/tests:complex_unary_op_test_cpu fails on macOS Apple Silicon #19824

Open majnemer opened 5 days ago

majnemer commented 5 days ago

The //xla/tests:complex_unary_op_test_cpu test fails on my Mac mini.

I hacked the test up a bit to simplify things.

[ RUN      ] ComplexUnaryOpTest.Log1pTest
2024-11-24 17:54:52.487071: I xla/tests/literal_test_util.cc:56] expected: c64[2] c64[2] {(inf, -2.3561945), (-0.5, -0.5)}
2024-11-24 17:54:52.487089: I xla/tests/literal_test_util.cc:58] actual:   c64[2] c64[2] {(inf, -2.3561945), (-0.5, -0)}

The failing case uses the following input:

 { { -min, -min }, { -min, -min }, 0x1p+125f }

After a bit of further investigation, I determined that the root cause is that the implementation of atan2f depends on subnormal support even when its inputs and outputs are not subnormal.

#include <cmath>
#include <cfloat>
#include <cstdio>
#include <fenv.h>  // for Apple's FE_DFL_DISABLE_DENORMS_ENV

static void doit() {
  volatile const float a = -FLT_MIN;  // negated smallest normal float
  volatile const float b = 1.0f;
  // atan2f(a, b) ~= a for tiny a, so the result should be about -FLT_MIN.
  volatile float c = atan2f(a, b);
  // Scale by 2^125 so a correct result (-FLT_MIN * 2^125 = -0.5) is easy
  // to tell apart from a flushed-to-zero result (-0).
  c *= 0x1p+125f;
  printf("%.19g\n", c);
}

int main() {
  doit();                                // default environment
  fesetenv(FE_DFL_DISABLE_DENORMS_ENV);  // disable subnormals (Apple-specific)
  doit();
  return 0;
}

On Darwin Kernel Version 24.1.0 (Thu Nov 14 18:15:21 PST 2024; root:xnu-11215.41.3~13/RELEASE_ARM64_T6041), this outputs:

-0.5
-0
majnemer commented 5 days ago

OK, so here is what happens.

atan2f ends up doing its computation in double precision, producing the result:

(lldb) reg read -f"float64" d0
      d0 = {-1.1754943508222874E-38}

Its magnitude is just under FLT_MIN (1.17549435082228750797e-38), so it gets flushed to 0 when it is converted from double precision to single precision with subnormals disabled.
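
For reference, the double-to-float conversion alone reproduces the flush. Here is a minimal sketch (assuming Apple's FE_DFL_DISABLE_DENORMS_ENV and reusing the double value observed in lldb above):

#include <cfloat>
#include <cstdio>
#include <fenv.h>  // Apple's FE_DFL_DISABLE_DENORMS_ENV

int main() {
  // The double-precision intermediate observed above; its magnitude is
  // just below FLT_MIN, i.e. in the single-precision subnormal range.
  volatile double d = -1.1754943508222874e-38;

  volatile float f = (float)d;
  printf("%.19g\n", f);  // default env: rounds to -FLT_MIN

  fesetenv(FE_DFL_DISABLE_DENORMS_ENV);
  f = (float)d;
  printf("%.19g\n", f);  // subnormals disabled: flushed to -0
  return 0;
}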

majnemer commented 5 days ago

@pearu, can you please take a look?

pearu commented 5 days ago

Sure, this issue sounds very similar to the data point in https://github.com/jax-ml/jax/issues/24787#issuecomment-2501337976: on Mac platforms, operations on the smallest normal value are flushed to zero, while on other platforms they are not. A simple fix is to adjust the test samples as in https://github.com/jax-ml/jax/pull/25117 (use nextafter(min, 1) instead of min); otherwise, eliminating these corner-case differences between the Mac and Linux platforms may be a complicated task.
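
To illustrate the suggested adjustment, here is a minimal sketch (not the actual PR change) showing that nudging the sample one ULP above the smallest normal keeps atan2f's intermediate clear of the subnormal boundary:

#include <cfloat>
#include <cmath>
#include <cstdio>
#include <fenv.h>  // Apple's FE_DFL_DISABLE_DENORMS_ENV

int main() {
  fesetenv(FE_DFL_DISABLE_DENORMS_ENV);

  // nextafterf(FLT_MIN, 1) is one ULP above the smallest normal, so the
  // double-precision intermediate inside atan2f stays above the
  // single-precision subnormal range and survives the flush.
  volatile float a = -nextafterf(FLT_MIN, 1.0f);
  volatile float c = atan2f(a, 1.0f);
  printf("%.19g\n", c * 0x1p+125f);  // about -0.5 rather than -0
  return 0;
}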