openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

[cpu] LLVM error during compilation of bf16 convert on aarch64 #19105

Open hawkinsp opened 1 week ago

hawkinsp commented 1 week ago

The following JAX test crashes when compiled on a GCP c4a Axion ARM VM:

$ python tests/lax_test.py LaxTest.testConvGeneralDilatedLocal8
Running tests under Python 3.12.3: /home/phawkins/myenv/bin/python
[ RUN      ] LaxTest.testConvGeneralDilatedLocal8 (n=2, lhs_spec='NHCW', rhs_spec='OHWI', out_spec='HWCN', dtype=<class 'ml_dtypes.bfloat16'>, precision=Precision.HIGHEST, padding='VALID')
LLVM ERROR: Cannot select: 0xe0d25c00bb70: nxv8bf16 = AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU 0xe0d25c152000, 0xe0d25c00bda0, undef:nxv8bf16
  0xe0d25c152000: nxv8i1 = AArch64ISD::PTRUE TargetConstant:i32<31>
    0xe0d25c152e00: i32 = TargetConstant<31>
  0xe0d25c00bda0: nxv8i16,ch = masked_load<(load unknown-size from %ir.scevgep7, align 1, !noalias !4), zext from nxv8i8> 0xe0d25c0e8c40, 0xe0d25c00bb00, undef:i64, 0xe0d25c1520e0, undef:nxv8i16
    0xe0d25c00bb00: i64 = add 0xe0d25c00b9b0, 0xe0d25c00ba90
      0xe0d25c00b9b0: i64,ch = CopyFromReg 0xe0d25c0e8c40, Register:i64 %5
        0xe0d25c152540: i64 = Register %5
      0xe0d25c00ba90: i64,ch = CopyFromReg 0xe0d25c0e8c40, Register:i64 %7
        0xe0d25c00ba20: i64 = Register %7
    0xe0d25c152620: i64 = undef
    0xe0d25c1520e0: nxv8i1,ch = CopyFromReg 0xe0d25c0e8c40, Register:nxv8i1 %8
      0xe0d25c152d90: nxv8i1 = Register %8
    0xe0d25c152c40: nxv8i16 = undef
  0xe0d25c1522a0: nxv8bf16 = undef
In function: convert.2
Fatal Python error: Aborted

I believe the crashing module is just this one, since it's the last thing the program dumps before aborting:

HloModule jit_convert_element_type, entry_computation_layout={(pred[6,6]{1,0})->bf16[6,6]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.3 {
  Arg_0.1 = pred[6,6]{1,0} parameter(0), metadata={op_name="args[0]"}
  ROOT convert.2 = bf16[6,6]{1,0} convert(Arg_0.1), metadata={op_name="jit(convert_element_type)/jit(main)/convert_element_type" source_file="/home/phawkins/jax/tests/lax_test.py" source_line=775}
}
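
For anyone trying to reproduce this outside the full test, here is a minimal sketch. It assumes the standalone `pred[6,6]` to `bf16[6,6]` convert above is enough to trigger the failure on its own, which has not been confirmed. The helper name `convert_pred_to_bf16` and the input values are made up for illustration; only the shapes and dtypes come from the dump.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Boolean (pred) input with the same shape as the dumped module: pred[6,6].
# The values are arbitrary; half the entries are True.
x = jnp.arange(36).reshape(6, 6) % 2 == 0


@jax.jit
def convert_pred_to_bf16(a):
    # Mirrors the single HLO op in the dump: convert(pred) -> bf16.
    return lax.convert_element_type(a, jnp.bfloat16)


# Print the lowered IR so it can be compared against the dumped module.
print(convert_pred_to_bf16.lower(x).as_text())

# If the assumption holds, the LLVM "Cannot select" abort would happen here,
# when the CPU backend compiles the convert. On unaffected machines this
# simply returns a bf16 array.
y = convert_pred_to_bf16(x)
print(y.dtype, y.shape)
```

Because the failure is a process abort rather than a Python exception, it cannot be caught with `try`/`except`, so this is best run as its own script.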

A c4a VM describes its CPU architecture this way:

$ lscpu
Architecture:             aarch64
  CPU op-mode(s):         64-bit
  Byte Order:             Little Endian
CPU(s):                   32
  On-line CPU(s) list:    0-31
Vendor ID:                ARM
  Model name:             Neoverse-V2
    Model:                1
    Thread(s) per core:   1
    Core(s) per socket:   32
    Socket(s):            1
    Stepping:             r0p1
    BogoMIPS:             2000.00
    Flags:                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh rng bti
Caches (sum of all):
  L1d:                    2 MiB (32 instances)
  L1i:                    2 MiB (32 instances)
  L2:                     64 MiB (32 instances)
  L3:                     80 MiB (1 instance)
NUMA:
  NUMA node(s):           1
  NUMA node0 CPU(s):      0-31
Vulnerabilities:
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Not affected
  Mds:                    Not affected
  Meltdown:               Not affected
  Mmio stale data:        Not affected
  Reg file data sampling: Not affected
  Retbleed:               Not affected
  Spec rstack overflow:   Not affected
  Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:             Mitigation; __user pointer sanitization
  Spectre v2:             Not affected
  Srbds:                  Not affected
  Tsx async abort:        Not affected

penpornk commented 6 days ago

cc: @milpuz01