jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Strange behavior of `convert_element_type` on main branch #12452

Open jakevdp opened 2 years ago

jakevdp commented 2 years ago

Running from github HEAD:

import jax.numpy as jnp
import jax

jax.config.update('jax_enable_x64', True)

x = jnp.float64(2.718)
x_f16 = x.astype('float16')

print(x, x_f16)
# 2.718 -15790.0

version info:

jax.print_environment_info()
# jax:    0.3.18
# jaxlib: 0.3.15
# numpy:  1.23.2
# python: 3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18)  [Clang 6.0 (clang-600.0.57)]
# jax.devices (1 total, 1 local): [CpuDevice(id=0)]
# process_count: 1

I can only reproduce this locally on my non-M1 MacBook; I have not been able to reproduce it in Colab or on Linux.

nicholasjng commented 2 years ago

FWIW, this does not reproduce on my MacBook, albeit with a locally built jaxlib 0.3.18, also at HEAD:

>>> import jax.numpy as jnp
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> x = jnp.float64(2.718)
>>> x_f16 = x.astype("float16")
>>> print(x, x_f16)
2.718 2.719

>>> jax.print_environment_info()
jax:    0.3.18
jaxlib: 0.3.18
numpy:  1.23.3
python: 3.10.6 (main, Aug 30 2022, 04:58:14) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

soraros commented 2 years ago

It does reproduce on my MacBook, with jaxlib 0.3.15, and I think that's where the problem lies.

jax.print_environment_info()
# jax:    0.3.18
# jaxlib: 0.3.15
# numpy:  1.23.3
# python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
# jax.devices (1 total, 1 local): [CpuDevice(id=0)]
# process_count: 1

hawkinsp commented 2 years ago

This does not reproduce on my M1 MacBook Pro with either jaxlib 0.3.15 or jaxlib from HEAD.

For those of you who see this failure: are you using Intel or ARM MacBooks? Can you try building jaxlib from HEAD and see if it still reproduces? I'm speculating that this is already fixed in an up-to-date jaxlib.
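A quick way to answer the Intel-vs-ARM question from Python is sketched below. It uses only the standard-library `platform` module; the helper name `describe_host` is made up for this illustration. Note that an x86-64 Python running under Rosetta on Apple silicon may report `x86_64`, so treat the result as describing the running interpreter rather than the physical CPU.

```python
import platform


def describe_host() -> str:
    """Return a one-line summary of the OS and the interpreter's CPU architecture."""
    system = platform.system()    # e.g. "Darwin" on macOS, "Linux" on Linux
    machine = platform.machine()  # e.g. "x86_64" (Intel) or "arm64" (Apple silicon)
    return f"{system} {machine}"


print(describe_host())
```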

jakevdp commented 2 years ago

Mine is an Intel MacBook.

soraros commented 2 years ago

Mine is also an Intel MacBook, and the problem persists with a locally built jaxlib from HEAD.

In [1]: import jax.numpy as jnp
   ...: import jax
   ...: 
   ...: jax.config.update('jax_enable_x64', True)
   ...: 
   ...: x = jnp.float64(2.718)
   ...: x_f16 = x.astype('float16')
   ...: 
   ...: print(x, x_f16)
   ...: # 2.718 -15790.0
2.718 -15790.0

In [2]: jax.print_environment_info()
jax:    0.3.18
jaxlib: 0.3.18
numpy:  1.23.3
python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

hawkinsp commented 2 years ago

Could you share the details of what CPU you have? Sharing the output of `sysctl -a | grep machdep.cpu` should do it.

soraros commented 2 years ago

@hawkinsp Here you go.

```
machdep.cpu.mwait.linesize_min: 64
machdep.cpu.mwait.linesize_max: 64
machdep.cpu.mwait.extensions: 3
machdep.cpu.mwait.sub_Cstates: 286531872
machdep.cpu.thermal.sensor: 1
machdep.cpu.thermal.dynamic_acceleration: 1
machdep.cpu.thermal.invariant_APIC_timer: 1
machdep.cpu.thermal.thresholds: 2
machdep.cpu.thermal.ACNT_MCNT: 1
machdep.cpu.thermal.core_power_limits: 1
machdep.cpu.thermal.fine_grain_clock_mod: 1
machdep.cpu.thermal.package_thermal_intr: 1
machdep.cpu.thermal.hardware_feedback: 0
machdep.cpu.thermal.energy_policy: 1
machdep.cpu.xsave.extended_state: 31 832 1088 0
machdep.cpu.xsave.extended_state1: 15 832 256 0
machdep.cpu.arch_perf.version: 4
machdep.cpu.arch_perf.number: 4
machdep.cpu.arch_perf.width: 48
machdep.cpu.arch_perf.events_number: 7
machdep.cpu.arch_perf.events: 0
machdep.cpu.arch_perf.fixed_number: 3
machdep.cpu.arch_perf.fixed_width: 48
machdep.cpu.cache.linesize: 64
machdep.cpu.cache.L2_associativity: 4
machdep.cpu.cache.size: 256
machdep.cpu.tlb.inst.large: 8
machdep.cpu.tlb.data.small: 64
machdep.cpu.tlb.data.small_level1: 64
machdep.cpu.address_bits.physical: 39
machdep.cpu.address_bits.virtual: 48
machdep.cpu.tsc_ccc.numerator: 200
machdep.cpu.tsc_ccc.denominator: 2
machdep.cpu.max_basic: 22
machdep.cpu.max_ext: 2147483656
machdep.cpu.vendor: GenuineIntel
machdep.cpu.brand_string: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
machdep.cpu.family: 6
machdep.cpu.model: 158
machdep.cpu.extmodel: 9
machdep.cpu.extfamily: 0
machdep.cpu.stepping: 13
machdep.cpu.feature_bits: 9221959987971750911
machdep.cpu.leaf7_feature_bits: 43804591 1073741824
machdep.cpu.leaf7_feature_bits_edx: 3154120192
machdep.cpu.extfeature_bits: 1241984796928
machdep.cpu.signature: 591597
machdep.cpu.brand: 0
machdep.cpu.features: FPU VME DE PSE TSC MSR PAE MCE CX8 APIC SEP MTRR PGE MCA CMOV PAT PSE36 CLFSH DS ACPI MMX FXSR SSE SSE2 SS HTT TM PBE SSE3 PCLMULQDQ DTES64 MON DSCPL VMX EST TM2 SSSE3 FMA CX16 TPR PDCM SSE4.1 SSE4.2 x2APIC MOVBE POPCNT AES PCID XSAVE OSXSAVE SEGLIM64 TSCTMR AVX1.0 RDRAND F16C
machdep.cpu.leaf7_features: RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 AVX2 SMEP BMI2 ERMS INVPCID FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT SGXLC MDCLEAR IBRS STIBP L1DF ACAPMSR SSBD
machdep.cpu.extfeatures: SYSCALL XD 1GBPAGE EM64T LAHF LZCNT PREFETCHW RDTSCP TSCI
machdep.cpu.logical_per_package: 16
machdep.cpu.cores_per_package: 8
machdep.cpu.microcode_version: 240
machdep.cpu.processor_flag: 5
machdep.cpu.core_count: 8
machdep.cpu.thread_count: 16
```

hawkinsp commented 2 years ago

Could someone who experiences this problem please run the reproduction with the environment variable `XLA_FLAGS=--xla_dump_to=/tmp/somewhere` and share a zip file or similar of the directory of files it produces?
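One way to produce such a dump is to launch the reproduction in a child process with `XLA_FLAGS` set, as sketched below. The helper names `env_with_xla_dump` and `run_repro` are invented for this sketch, and it assumes jax is installed in the interpreter that runs it. Setting the variable in the child's environment avoids any question about whether it was set before jax was imported.

```python
import os
import subprocess
import sys

# The reproduction from the top of this issue, condensed.
REPRO = """
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
x = jnp.float64(2.718)
print(x, x.astype('float16'))
"""


def env_with_xla_dump(dump_dir: str) -> dict:
    """Copy the current environment and point XLA's dump output at dump_dir.

    Note: this overwrites any XLA_FLAGS value that is already set.
    """
    env = dict(os.environ)
    env["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"
    return env


def run_repro(dump_dir: str) -> subprocess.CompletedProcess:
    """Run the reproduction in a fresh interpreter with the dump flag set."""
    return subprocess.run(
        [sys.executable, "-c", REPRO],
        env=env_with_xla_dump(dump_dir),
        capture_output=True,
        text=True,
    )
```

Calling `run_repro("/tmp/somewhere")` and then zipping `/tmp/somewhere` should give the directory of files requested above.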

soraros commented 2 years ago

@hawkinsp Here you go again. dump.zip

hawkinsp commented 2 years ago

My current best guess is that this may be related to the `_Float16` ABI changing in LLVM.

This program in essence compiles into a single call to a builtin named `__truncdfhf2`. The only explanation I can think of is a mismatch of calling conventions on Mac x86-64.

Notably this changed in LLVM recently: https://reviews.llvm.org/D131172

d0k commented 2 years ago

LLVM changed the x86-64 fp16 ABI from passing the value in integer registers to passing it in floating-point registers. This means that the `__truncdfhf2` provided by the system is now incompatible with the LLVM in jaxlib.

If I remember correctly, `__truncdfhf2` is provided by the system on macOS, not by Xcode. This would mean we're stuck on the wrong ABI until Apple decides to change it, which seems unlikely to happen outside of a major release. I'm curious what will happen when they hit this change when updating Clang in Xcode, though.

This isn't an issue for f32->f16 because there has been a hardware instruction for it since Haswell, but there is none for f64->f16 (one exists in AVX512, but no Mac was released with it).

I think we can work around this by making `simple_orc_jit` bind `__truncdfhf2` to the fallback version in `runtime_fp16.cc`, which I fixed to use the correct ABI. I have no Intel Mac around to test that.
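To illustrate how reading a half-precision result from the wrong bits can turn 2.718 into the value reported at the top of this issue, the sketch below reinterprets the low 16 bits of the float32 encoding of 2.718 as a float16. This is only a bit-level illustration using the standard-library `struct` module; it does not establish what the registers actually held during the failing call, and the helper names are made up for the example.

```python
import struct


def round_to_f16(x: float) -> float:
    """Correctly round a Python float (f64) to f16 and widen it back to a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def low16_of_f32_as_f16(x: float) -> float:
    """Encode x as little-endian f32, then read its low 16 bits as an f16."""
    f32_bytes = struct.pack("<f", x)  # 4 bytes, least significant byte first
    return struct.unpack("<e", f32_bytes[:2])[0]


expected = round_to_f16(2.718)        # what a correct f64 -> f16 cast yields
garbage = low16_of_f32_as_f16(2.718)  # what a misread 16-bit slice yields

print(expected, garbage)
# 2.71875 -15792.0
```

NumPy prints these two float16 values as `2.719` and `-15790.0`, which match the good and bad outputs reported above.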