jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Pallas kernel crash at `llo::CouldLtS32` when `interpret=False` on TPU #23181

Open nathom opened 1 month ago

nathom commented 1 month ago

Description

Hello, I'm running into a core dump when running a Pallas TPU kernel. With `interpret=True` the kernel works and matches the reference output; with `interpret=False` the process aborts with a core dump. Any temporary fix is appreciated!

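import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.nn import gelu  # assumed: the report does not show where gelu comes from

def matmul_bias_gelu_kernel(x_ref, y_ref, b_ref, z_ref, acc_ref, *, nsteps):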
    @pl.when(pl.program_id(axis=4) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)

    print(f"{ acc_ref[...].shape =}", f"{ x_ref[...].shape =}", f"{ y_ref[...].shape =}", f"{b_ref[...].shape = }")
    # Accumulate x @ y over the k (reduction) grid axis.
    acc_ref[...] += jnp.matmul(
        x_ref[...].squeeze(axis=(0, 1)), y_ref[...].squeeze(axis=(0,)), preferred_element_type=jnp.float32
    )

    @pl.when(pl.program_id(axis=4) == nsteps - 1)
    def _():
        # Add the bias once, after the last k step, then apply GELU.
        z_ref[...] = gelu((acc_ref[...] + b_ref[...]).astype(z_ref.dtype))

@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'bl', 'bnc'])
def matmul_bias_gelu(
    x: jax.Array,
    y: jax.Array,
    b: jax.Array,
    *,
    bl: int = 1,
    bm: int = 16,
    bk: int = 64,
    bn: int = 128,
    bnc: int = 1,
):
    """Compute gelu(x @ y + b)."""

    nc, l, m, k = x.shape
    l2, k2, n = y.shape
    one, l3, _one, n2 = b.shape
    assert l == l2 == l3 and k == k2 and one == 1 and _one == 1 and n == n2, f'Invalid dims {x.shape=} {y.shape=} {b.shape=}'
    assert l % bl == 0 and m % bm == 0 and n % bn == 0 and k % bk == 0 and nc % bnc == 0, 'Dims must be multiples of block sizes'

    return pl.pallas_call(
        functools.partial(matmul_bias_gelu_kernel, nsteps=k // bk),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((bnc, bl, bm, bk), lambda nc, l, i, j, k: (nc, l, i, k)),
                pl.BlockSpec((bl, bk, bn), lambda nc, l, i, j, k: (l, k, j)),
                pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (1, l, 1, j)),
            ],
            out_specs=pl.BlockSpec((bnc, bl, bm, bn), lambda nc, l, i, j, k: (nc, l, i, j)),
            scratch_shapes=[pltpu.VMEM((bnc, bl, bm, bn), jnp.float32)],
            grid=(nc // bnc, l // bl, m // bm, n // bn, k // bk),
        ),
        out_shape=jax.ShapeDtypeStruct((nc, l, m, n), x.dtype),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary"))),
        # interpret=True, # doesn't work with interpret=False
    )(x, y, b)
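The grid has one entry per block along each of the (nc, l, m, n, k) axes, with k as the reduction axis. As a worked example, the pure-Python sketch below plugs the shapes from the test that follows into the same expression used in `matmul_bias_gelu`; `grid_for` is a hypothetical helper written only for this illustration, and the default block sizes are assumed.

```python
def grid_for(x_shape, y_shape, *, bl=1, bm=16, bk=64, bn=128, bnc=1):
    """Hypothetical helper mirroring the grid expression in matmul_bias_gelu."""
    nc, l, m, k = x_shape
    _, _, n = y_shape
    # One grid step per block along (nc, l, m, n, k); the last axis is the reduction.
    return (nc // bnc, l // bl, m // bm, n // bn, k // bk)


# Shapes taken from the test in the report.
BS, NH, L, HF = 16, 32, 16 * 128, 64
CS = 16
NC = L // CS          # 128 chunks
HF_prime = 4 * HF     # 256

grid = grid_for((NC, BS * NH, CS, HF), (BS * NH, HF, HF_prime))
print(grid)  # (128, 512, 1, 2, 1)
```

For these shapes the reduction axis has a single step (`nsteps == 1`), so both `pl.when` branches run in the same grid iteration.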

I ran it with the following test:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
BS, NH, L, HF = 16, 32, 16 * 128, 64
CS = 16
NC = L // CS
HF_prime = 4 * HF

XB = random.normal(key, (NC, BS*NH, CS, HF))
W1 = random.normal(key, (BS*NH, HF, HF_prime))
b1 = random.normal(key, (1, BS*NH, 1, HF_prime))

output_ref = matmul_bias_gelu_ref(XB, W1, b1)  # plain-JAX reference; definition not included in the report
output_kernel = matmul_bias_gelu(XB, W1, b1)
print('max abs error', jnp.max(jnp.abs(output_ref - output_kernel)))
print(f"{output_ref.shape = } {output_kernel.shape = }")
assert jnp.allclose(output_ref, output_kernel, atol=1e-2)
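The test calls `matmul_bias_gelu_ref`, whose definition is not part of the report. A minimal sketch of what such a plain-JAX reference could look like is below; it assumes `gelu` means `jax.nn.gelu` (which uses the tanh approximation by default) and that the shapes follow the kernel's contract, x: (nc, l, m, k), y: (l, k, n), b: (1, l, 1, n).

```python
import jax
import jax.numpy as jnp
from jax import random


def matmul_bias_gelu_ref(x, y, b):
    """Hypothetical reference for gelu(x @ y + b); the reporter's version is not shown."""
    # Batched matmul over the shared l axis; nc is an extra leading batch axis.
    xy = jnp.einsum("clmk,lkn->clmn", x, y, preferred_element_type=jnp.float32)
    # b has shape (1, l, 1, n) and broadcasts over the nc and m axes.
    return jax.nn.gelu((xy + b).astype(x.dtype))


# Small shapes so the check runs quickly on CPU.
kx, ky, kb = random.split(random.PRNGKey(0), 3)
x = random.normal(kx, (2, 3, 16, 64))
y = random.normal(ky, (3, 64, 128))
b = random.normal(kb, (1, 3, 1, 128))

out = matmul_bias_gelu_ref(x, y, b)
print(out.shape)  # (2, 3, 16, 128)
```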

When interpret=True, the assertion passes. When interpret=False, I get a core dump:

F0821 21:26:33.618965  402385 math_util.cc:68] Check failed: llo::CouldLtS32(digits[i], bounds[i]) 
*** Check failure stack trace: ***
    @     0x7f021a32ebe4  (unknown)
    @     0x7f021a32e6e8  (unknown)
    @     0x7f021a32f009  (unknown)
    @     0x7f021301eb4d  (unknown)
    @     0x7f02130122ec  (unknown)
    @     0x7f021301287b  (unknown)
    @     0x7f021300ad7e  (unknown)
    @     0x7f02130017f1  (unknown)
    @     0x7f0212ffe478  (unknown)
    @     0x7f02109c1908  (unknown)
    @     0x7f02109bd218  (unknown)
    @     0x7f02109b1a53  (unknown)
    @     0x7f0210999578  (unknown)
    @     0x7f02109b1dd9  (unknown)
    @     0x7f02109b61ce  (unknown)
    @     0x7f02109b9307  (unknown)
    @     0x7f0219f3499b  (unknown)
    @     0x7f0219f3b224  (unknown)
    @     0x7f0219f44045  (unknown)
    @     0x7f021a1fce53  (unknown)
    @     0x7f02cb494ac3  (unknown)
https://symbolize.stripped_domain/r/?trace=7f021a32ebe4,7f021a32e6e7,7f021a32f008,7f021301eb4c,7f02130122eb,7f021301287a,7f021300ad7d,7f02130017f0,7f0212ffe477,7f02109c1907,7f02109bd217,7f02109b1a52,7f0210999577,7f02109b1dd8,7f02109b61cd,7f02109b9306,7f0219f3499a,7f0219f3b223,7f0219f44044,7f021a1fce52,7f02cb494ac2&map= 
https://symbolize.stripped_domain/r/?trace=7f02cb4969fc,7f02cb44251f&map= 
*** SIGABRT received by PID 401472 (TID 402385) on cpu 12 from PID 401472; ***
E0821 21:26:33.654562  402385 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0821 21:26:33.654581  402385 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0821 21:26:33.654589  402385 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0821 21:26:33.654595  402385 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0821 21:26:33.654622  402385 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0821 21:26:33.654630  402385 coredump_hook.cc:472] RAW: Dumping core locally.
F0821 21:26:33.618965  402385 math_util.cc:68] Check failed: llo::CouldLtS32(digits[i], bounds[i]) 
E0821 21:26:33.900571  402385 process_state.cc:805] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax
>>> jax.print_environment_info()
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.0
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-187c2405-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

dfm commented 1 month ago

Pinging @sharadmv who will know best.

justinjfu commented 1 month ago

I think there is an out-of-bounds bug in the kernel that you wrote, which is hitting a runtime bounds check.

Specifically, the block spec for b:

pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (1, l, 1, j)),

should be:

pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (0, l, 0, j)),
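With the default blocked indexing, an index map returns block indices, and the element offset along each axis is block index × block size. `b` has shape (1, L, 1, N) and block shape (1, bl, 1, bn), so axes 0 and 2 each hold exactly one block and the only valid block index there is 0. The pure-Python sketch below evaluates an index map over a whole grid and reports the first out-of-range block index; `check_index_map` is a hypothetical helper (not a Pallas API), and the grid and shapes are small illustrative values.

```python
import itertools


def check_index_map(array_shape, block_shape, index_map, grid):
    """Return the first (grid_idx, block_idx) whose block lies outside the array, else None.

    Hypothetical helper: assumes blocked indexing (offset = block index * block size).
    """
    num_blocks = [dim // blk for dim, blk in zip(array_shape, block_shape)]
    for grid_idx in itertools.product(*(range(g) for g in grid)):
        block_idx = index_map(*grid_idx)
        if any(not 0 <= bi < nb for bi, nb in zip(block_idx, num_blocks)):
            return grid_idx, block_idx
    return None


grid = (2, 4, 1, 2, 1)       # (nc, l, m, n, k) blocks, illustrative sizes
b_shape = (1, 4, 1, 256)     # (1, L, 1, N)
b_block = (1, 1, 1, 128)     # (1, bl, 1, bn) with bl=1, bn=128

buggy = lambda nc, l, i, j, k: (1, l, 1, j)
fixed = lambda nc, l, i, j, k: (0, l, 0, j)

print(check_index_map(b_shape, b_block, buggy, grid))  # ((0, 0, 0, 0, 0), (1, 0, 1, 0))
print(check_index_map(b_shape, b_block, fixed, grid))  # None
```

The buggy map fails on the very first grid step because it asks for block 1 on an axis that has only one block.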

sharadmv commented 1 month ago

I think we could catch this error in interpret mode if we use checkify to look for OOB indexing.
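For ordinary JAX code, `jax.experimental.checkify` can already flag out-of-bounds indexing through its `index_checks` error set. The sketch below applies it to a hypothetical `load_block` helper that gathers one block along axis 0 using blocked indexing; it only illustrates the kind of check being proposed and is not wired into Pallas interpret mode.

```python
import functools

import jax.numpy as jnp
from jax.experimental import checkify


def load_block(arr, block_idx, block_size):
    """Hypothetical helper: gather block `block_idx` of length `block_size` along axis 0."""
    start = block_idx * block_size
    # Integer-array indexing lowers to a gather, which index_checks inspects.
    return arr[start + jnp.arange(block_size)]


# block_size must stay static (it sizes jnp.arange), so bind it before checkify traces.
checked_load = checkify.checkify(
    functools.partial(load_block, block_size=1), errors=checkify.index_checks
)

b_axis0 = jnp.zeros((1, 8))  # stand-in for b's leading axis of size 1

err_ok, _ = checked_load(b_axis0, 0)   # block index from the fixed index map
err_bad, _ = checked_load(b_axis0, 1)  # block index from the buggy index map
print(err_ok.get())
print(err_bad.get())
```

`Error.get()` returns `None` when no check fired and a message string otherwise; `err.throw()` raises instead.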