Open hawkinsp opened 1 week ago
The following JAX test crashes when compiled on a GCP c4a Axion ARM VM:
$ python tests/lax_test.py LaxTest.testConvGeneralDilatedLocal8
Running tests under Python 3.12.3: /home/phawkins/myenv/bin/python
[ RUN ] LaxTest.testConvGeneralDilatedLocal8 (n=2, lhs_spec='NHCW', rhs_spec='OHWI', out_spec='HWCN', dtype=<class 'ml_dtypes.bfloat16'>, precision=Precision.HIGHEST, padding='VALID')
LLVM ERROR: Cannot select: 0xe0d25c00bb70: nxv8bf16 = AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU 0xe0d25c152000, 0xe0d25c00bda0, undef:nxv8bf16
  0xe0d25c152000: nxv8i1 = AArch64ISD::PTRUE TargetConstant:i32<31>
    0xe0d25c152e00: i32 = TargetConstant<31>
  0xe0d25c00bda0: nxv8i16,ch = masked_load<(load unknown-size from %ir.scevgep7, align 1, !noalias !4), zext from nxv8i8> 0xe0d25c0e8c40, 0xe0d25c00bb00, undef:i64, 0xe0d25c1520e0, undef:nxv8i16
    0xe0d25c00bb00: i64 = add 0xe0d25c00b9b0, 0xe0d25c00ba90
      0xe0d25c00b9b0: i64,ch = CopyFromReg 0xe0d25c0e8c40, Register:i64 %5
        0xe0d25c152540: i64 = Register %5
      0xe0d25c00ba90: i64,ch = CopyFromReg 0xe0d25c0e8c40, Register:i64 %7
        0xe0d25c00ba20: i64 = Register %7
    0xe0d25c152620: i64 = undef
    0xe0d25c1520e0: nxv8i1,ch = CopyFromReg 0xe0d25c0e8c40, Register:nxv8i1 %8
      0xe0d25c152d90: nxv8i1 = Register %8
    0xe0d25c152c40: nxv8i16 = undef
  0xe0d25c1522a0: nxv8bf16 = undef
In function: convert.2
Fatal Python error: Aborted
I believe the crashing module is this one, since it is the last thing the program dumps before aborting and the error above names the function convert.2:
HloModule jit_convert_element_type, entry_computation_layout={(pred[6,6]{1,0})->bf16[6,6]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.3 {
  Arg_0.1 = pred[6,6]{1,0} parameter(0), metadata={op_name="args[0]"}
  ROOT convert.2 = bf16[6,6]{1,0} convert(Arg_0.1), metadata={op_name="jit(convert_element_type)/jit(main)/convert_element_type" source_file="/home/phawkins/jax/tests/lax_test.py" source_line=775}
}
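If the HLO module above really is the one that crashes, a standalone reproducer might look like the sketch below. It is only an assumption drawn from the dump (a pred[6,6] input converted to bf16[6,6]); I have not confirmed that it triggers the same LLVM error, and the input values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

# Boolean (pred) input with the same shape as in the HLO dump: pred[6,6].
# The values are arbitrary; only the dtype and shape come from the dump.
x = jnp.ones((6, 6), dtype=jnp.bool_)

# Calling the primitive eagerly compiles a small module for it, which should
# correspond to the "jit_convert_element_type" module (pred[6,6] -> bf16[6,6]).
out = lax.convert_element_type(x, jnp.bfloat16)

print(out.dtype, out.shape)
```

On a machine that is not affected this should simply print the result dtype and shape. On the c4a VM I would expect it to abort during compilation, if the guess about the module is right.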
A c4a VM describes its CPU architecture this way:
$ lscpu
Architecture: aarch64
CPU op-mode(s): 64-bit
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: ARM
Model name: Neoverse-V2
Model: 1
Thread(s) per core: 1
Core(s) per socket: 32
Socket(s): 1
Stepping: r0p1
BogoMIPS: 2000.00
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp sve2 sveaes svepmull svebitp erm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh rng bti
Caches (sum of all):
  L1d: 2 MiB (32 instances)
  L1i: 2 MiB (32 instances)
  L2: 64 MiB (32 instances)
  L3: 80 MiB (1 instance)
NUMA:
  NUMA node(s): 1
  NUMA node0 CPU(s): 0-31
Vulnerabilities:
  Gather data sampling: Not affected
  Itlb multihit: Not affected
  L1tf: Not affected
  Mds: Not affected
  Meltdown: Not affected
  Mmio stale data: Not affected
  Reg file data sampling: Not affected
  Retbleed: Not affected
  Spec rstack overflow: Not affected
  Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1: Mitigation; __user pointer sanitization
  Spectre v2: Not affected
  Srbds: Not affected
  Tsx async abort: Not affected
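The failing node has type nxv8bf16, a scalable vector of bfloat16, so the SVE and bf16 entries in the Flags line are probably the relevant ones. The sketch below pulls those out of a flags string so other hosts can be compared. The helper name `relevant_features` and the choice of which flags matter are my assumptions, not something XLA or LLVM defines.

```python
# Flags that seem relevant to an nxv8bf16 selection failure (an assumption).
RELEVANT = ("sve", "sve2", "bf16", "svebf16")


def relevant_features(flags_line: str) -> list[str]:
    """Return the entries of RELEVANT that appear in a space-separated flags line."""
    flags = set(flags_line.split())
    return [name for name in RELEVANT if name in flags]


# Flags line copied from the lscpu output in this report.
issue_flags = (
    "fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid "
    "asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm "
    "dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp sve2 sveaes svepmull "
    "svebitp erm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh rng bti"
)

print(relevant_features(issue_flags))
```

For this VM all four flags are present. A host that lacks `sve` would presumably not take the scalable-vector code path at all, which may be why the crash has not shown up on other ARM machines.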
cc: @milpuz01