Open jakevdp opened 2 years ago
FWIW, this does not reproduce on my MacBook, albeit with a (locally built) jaxlib 0.3.18, also at HEAD:
>>> import jax.numpy as jnp
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> x = jnp.float64(2.718)
>>> x_f16 = x.astype("float16")
>>> print(x, x_f16)
2.718 2.719
>>> jax.print_environment_info()
jax: 0.3.18
jaxlib: 0.3.18
numpy: 1.23.3
python: 3.10.6 (main, Aug 30 2022, 04:58:14) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
It does reproduce on my MacBook, with jaxlib 0.3.15, and I think that's where the problem lies.
jax.print_environment_info()
# jax: 0.3.18
# jaxlib: 0.3.15
# numpy: 1.23.3
# python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
# jax.devices (1 total, 1 local): [CpuDevice(id=0)]
# process_count: 1
This does not reproduce on my M1 MacBook Pro with either jaxlib 0.3.15 or jaxlib from HEAD.
For those of you who see the failure, are you using Intel or ARM MacBooks? Can you try building jaxlib from HEAD and see if it reproduces? I'm speculating that this is already fixed in an up-to-date jaxlib.
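For anyone unsure which architecture their Python process is actually running on, here is a minimal standard-library sketch. The helper name `describe_cpu` is made up for illustration. Note that an x86-64 Python running under Rosetta on Apple Silicon would be expected to report `x86_64`, which is the architecture that matters for jaxlib.

```python
import platform


def describe_cpu() -> str:
    """Return the OS name and machine architecture seen by this Python process."""
    # platform.machine() typically reports "x86_64" on Intel Macs and
    # "arm64" on Apple Silicon (for a native ARM build of Python).
    return f"{platform.system()} {platform.machine()}"


print(describe_cpu())
```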
Mine is an Intel MacBook.
Mine is also an Intel MacBook, and the problem persists with a locally built jaxlib from HEAD.
In [1]: import jax.numpy as jnp
...: import jax
...:
...: jax.config.update('jax_enable_x64', True)
...:
...: x = jnp.float64(2.718)
...: x_f16 = x.astype('float16')
...:
...: print(x, x_f16)
...: # 2.718 -15790.0
2.718 -15790.0
In [2]: jax.print_environment_info()
jax: 0.3.18
jaxlib: 0.3.18
numpy: 1.23.3
python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
Could you share the details of what CPU you have? Sharing the output of sysctl -a | grep machdep.cpu should do it.
@hawkinsp Here you go.
Could someone who experiences this problem please run the reproduction with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere and share a zip file or similar with the directory of files it produces?
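If it is easier to do this from Python than from the shell, the request above could be sketched roughly as follows. The helper names `enable_xla_dump` and `zip_dump` are hypothetical, and the assumption is that XLA_FLAGS is set before JAX initializes its backend (simplest: before `import jax`). This overwrites any XLA_FLAGS value already in the environment.

```python
import os
import shutil

DUMP_DIR = "/tmp/somewhere"  # path taken from the request above


def enable_xla_dump(dump_dir: str = DUMP_DIR) -> None:
    """Ask XLA to write its dump files into dump_dir.

    Call this before importing jax so the flag is in place when XLA starts.
    Note: this replaces any existing XLA_FLAGS value.
    """
    os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"


def zip_dump(dump_dir: str = DUMP_DIR, archive_base: str = "xla_dump") -> str:
    """Pack the dump directory into <archive_base>.zip and return its path."""
    return shutil.make_archive(archive_base, "zip", root_dir=dump_dir)
```

Usage would be: call `enable_xla_dump()`, then import jax and run the reproduction from the top of this thread, then call `zip_dump()` and attach the resulting file.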
My current best guess is that this may be related to the _Float16 ABI changing in LLVM.
This program in essence compiles into a single call to a builtin named __truncdfhf2. The only explanation I can think of is that there is a mismatch of calling conventions for Mac x86-64.
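To make the expected result concrete without involving XLA or the system runtime, here is a small pure-Python sketch of the same float64 to float16 rounding, using the standard library's half-precision `struct` format. The function names are made up for illustration; this is a reference for the value, not how `__truncdfhf2` is implemented.

```python
import struct


def f64_to_f16_bits(x: float) -> int:
    """Round a Python float (IEEE binary64) to binary16 and return its raw 16 bits."""
    # "e" is struct's IEEE 754 half-precision format; "H" reads the same
    # two bytes back as an unsigned 16-bit integer.
    return struct.unpack("<H", struct.pack("<e", x))[0]


def f16_bits_to_float(bits: int) -> float:
    """Decode a raw 16-bit pattern as an IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


bits = f64_to_f16_bits(2.718)
print(hex(bits), f16_bits_to_float(bits))  # 0x4170 2.71875
```

The correct result, 2.71875, is what float16 printing shows as 2.719. For comparison, the float16 value nearest the reported -15790.0 is -15792.0, with bit pattern 0xF3B6, which shares nothing obvious with 0x4170. That would be consistent with the caller reading unrelated bits, though this sketch cannot confirm the cause.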
Notably this changed in LLVM recently: https://reviews.llvm.org/D131172
LLVM changed the x86-64 fp16 ABI from passing the value in integer registers to passing it in floating-point registers. This means that the __truncdfhf2 that comes from the system is now incompatible with the LLVM in jaxlib.
If I remember correctly, __truncdfhf2 is provided by the system on macOS, not Xcode. This would mean we're stuck on the wrong ABI until Apple decides to change it, which seems unlikely to happen outside of a major release. I'm curious what will happen when they hit this change while updating Clang in Xcode, though.
This isn't an issue for f32->f16 because there's been a hardware instruction for it since Haswell, but there's none for f64->f16 (it exists in AVX512, but no Mac was released with it).
I think we can work around this by making simple_orc_jit bind __truncdfhf2 to the fallback version in runtime_fp16.cc, which I fixed to use the correct ABI. I have no Intel Mac around to test that.
Running from GitHub HEAD:
version info:
I can only reproduce this locally on my non-M1 MacBook; I've not been able to reproduce it in Colab or on Linux.