jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas TPU] Core dump when comparing two boolean arrays #24030

Open ayaka14732 opened 1 week ago

ayaka14732 commented 1 week ago

Description

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.bool),
    # interpret=True,
)
def kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] == y_ref[...]

def main():
    x = jnp.array([True, True], dtype=jnp.bool)
    y = jnp.array([False, False], dtype=jnp.bool)
    out = kernel(x, y)
    print(out)

if __name__ == '__main__':
    main()

Error:

https://symbolize.stripped_domain/r/?trace=7fbe9de8075b,7fbf4944251f,7fbe9dd8b8fd,7fbe9dd38ee7,7fbe9dd38e7d,7fbe9dd3798c,7fbe9dec8953,7fbe9dec818a,7fbe9d64981d,7fbe9ecc263b,7fbe9ed16b27,7fbe9ec2d81e,5794dc&map= 
*** SIGSEGV (@0x18), see go/stacktraces#s15 received by PID 361853 (TID 361853) on cpu 131; stack trace: ***
PC: @     0x7fbe9de8075b  (unknown)  mlir::ShapedType::getShape()
    @     0x7fbe98eb5be1       1888  (unknown)
    @     0x7fbf49442520  103807888  (unknown)
    @     0x7fbe9dd8b8fe        352  mlir::arith::SelectOp::verifyInvariantsImpl()
    @     0x7fbe9dd38ee8         32  mlir::op_definition_impl::verifyTraits<>()
    @     0x7fbe9dd38e7e         32  mlir::Op<>::verifyInvariants()
    @     0x7fbe9dd3798d         80  mlir::RegisteredOperationName::Model<>::verifyInvariants()
    @     0x7fbe9dec8954        768  (anonymous namespace)::OperationVerifier::verifyOpAndDominance()
    @     0x7fbe9dec818b         32  mlir::verify()
    @     0x7fbe9d64981e         16  mlirOperationVerify
    @     0x7fbe9ecc263c        192  mlir::python::PyOperationBase::verify()
    @     0x7fbe9ed16b28         64  pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()
    @     0x7fbe9ec2d81f        560  pybind11::cpp_function::dispatcher()
    @           0x5794dd  (unknown)  (unknown)
    @ ... and at least 1 more frames
https://symbolize.stripped_domain/r/?trace=7fbe9de8075b,7fbe98eb5be0,7fbf4944251f,7fbe9dd8b8fd,7fbe9dd38ee7,7fbe9dd38e7d,7fbe9dd3798c,7fbe9dec8953,7fbe9dec818a,7fbe9d64981d,7fbe9ecc263b,7fbe9ed16b27,7fbe9ec2d81e,5794dc&map= 
E0930 20:53:32.324661  361853 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0930 20:53:32.324669  361853 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0930 20:53:32.324677  361853 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0930 20:53:32.324681  361853 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0930 20:53:32.324696  361853 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0930 20:53:32.324702  361853 coredump_hook.cc:472] RAW: Dumping core locally.
E0930 20:53:32.508649  361853 process_state.cc:805] RAW: Raising signal 11 with default behavior
[1]    361853 segmentation fault (core dumped)  python 2.py
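
For comparison, here is a minimal sketch of the result one would expect from the kernel, computed with plain `jax.numpy` outside of Pallas. The helper name `reference_eq` is hypothetical and exists only for illustration. This sketch does not exercise the Pallas TPU lowering path that crashes above.

```python
import jax.numpy as jnp


def reference_eq(x, y):
    # Hypothetical helper: plain elementwise equality, no Pallas involved.
    return x == y


x = jnp.array([True, True], dtype=jnp.bool_)
y = jnp.array([False, False], dtype=jnp.bool_)

expected = reference_eq(x, y)
print(expected)  # [False False]
```

Every element of `x` differs from the corresponding element of `y`, so a working kernel should produce an all-`False` boolean array of shape `(2,)`.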

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34.dev20240924+85a466d73
jaxlib: 0.4.33
numpy:  2.1.0
python: 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')

ayaka14732 commented 14 hours ago

It would be better to debug this after https://github.com/jax-ml/jax/pull/24086 lands.