ayaka14732 opened 1 week ago
```python
import functools

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.bool),
    # interpret=True,
)
def kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] == y_ref[...]


def main():
    x = jnp.array([True, True], dtype=jnp.bool)
    y = jnp.array([False, False], dtype=jnp.bool)
    out = kernel(x, y)
    print(out)


if __name__ == '__main__':
    main()
```
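For reference, the result the kernel should return can be computed with plain `jax.numpy` outside Pallas. The snippet below is a minimal sketch using the same inputs. It also checks that, for booleans, `==` is equivalent to XNOR (`logical_not(logical_xor(x, y))`). That equivalent formulation is only an assumption worth trying inside the kernel; it has not been checked against the Mosaic TPU lowering.

```python
import jax.numpy as jnp

# Same inputs as the repro.
x = jnp.array([True, True], dtype=jnp.bool_)
y = jnp.array([False, False], dtype=jnp.bool_)

# Reference result computed with plain jax.numpy, outside pallas_call.
expected = x == y

# For booleans, equality is XNOR: not (x xor y). Equivalent on paper only;
# whether this form avoids the TPU crash is untested.
xnor = jnp.logical_not(jnp.logical_xor(x, y))

print(expected)  # [False False]
print(bool(jnp.array_equal(expected, xnor)))  # True
```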
Error (running the script on TPU segfaults during MLIR verification of an `arith.select` op):
```
https://symbolize.stripped_domain/r/?trace=7fbe9de8075b,7fbf4944251f,7fbe9dd8b8fd,7fbe9dd38ee7,7fbe9dd38e7d,7fbe9dd3798c,7fbe9dec8953,7fbe9dec818a,7fbe9d64981d,7fbe9ecc263b,7fbe9ed16b27,7fbe9ec2d81e,5794dc&map=
*** SIGSEGV (@0x18), see go/stacktraces#s15 received by PID 361853 (TID 361853) on cpu 131; stack trace: ***
PC: @ 0x7fbe9de8075b (unknown) mlir::ShapedType::getShape()
    @ 0x7fbe98eb5be1 1888 (unknown)
    @ 0x7fbf49442520 103807888 (unknown)
    @ 0x7fbe9dd8b8fe 352 mlir::arith::SelectOp::verifyInvariantsImpl()
    @ 0x7fbe9dd38ee8 32 mlir::op_definition_impl::verifyTraits<>()
    @ 0x7fbe9dd38e7e 32 mlir::Op<>::verifyInvariants()
    @ 0x7fbe9dd3798d 80 mlir::RegisteredOperationName::Model<>::verifyInvariants()
    @ 0x7fbe9dec8954 768 (anonymous namespace)::OperationVerifier::verifyOpAndDominance()
    @ 0x7fbe9dec818b 32 mlir::verify()
    @ 0x7fbe9d64981e 16 mlirOperationVerify
    @ 0x7fbe9ecc263c 192 mlir::python::PyOperationBase::verify()
    @ 0x7fbe9ed16b28 64 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()
    @ 0x7fbe9ec2d81f 560 pybind11::cpp_function::dispatcher()
    @ 0x5794dd (unknown) (unknown)
    @ ... and at least 1 more frames
https://symbolize.stripped_domain/r/?trace=7fbe9de8075b,7fbe98eb5be0,7fbf4944251f,7fbe9dd8b8fd,7fbe9dd38ee7,7fbe9dd38e7d,7fbe9dd3798c,7fbe9dec8953,7fbe9dec818a,7fbe9d64981d,7fbe9ecc263b,7fbe9ed16b27,7fbe9ec2d81e,5794dc&map=
E0930 20:53:32.324661 361853 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0930 20:53:32.324669 361853 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0930 20:53:32.324677 361853 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0930 20:53:32.324681 361853 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0930 20:53:32.324696 361853 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0930 20:53:32.324702 361853 coredump_hook.cc:472] RAW: Dumping core locally.
E0930 20:53:32.508649 361853 process_state.cc:805] RAW: Raising signal 11 with default behavior
[1]    361853 segmentation fault (core dumped)  python 2.py
```
System info (python version, jaxlib version, accelerator, etc.)

```
jax: 0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy: 2.1.0
python: 3.12.4 (main, Jun 8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
```
It's better to debug after https://github.com/jax-ml/jax/pull/24086