jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

paged_attention results in a `core dumped` on TPU-v3 #23825

Status: Open. OhadRubin opened this issue 2 months ago.

OhadRubin commented 2 months ago

Description

Code to reproduce:

import jax
import numpy as np
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

if __name__ == '__main__':
    # Set up example inputs
    batch_size = 8
    num_pages = 32
    page_size = 16
    head_dim = 128
    total_num_pages = 64*8
    num_kv_heads = num_heads = 4
    pages_per_compute_block=8
    pages_per_sequence=8
    rng = np.random.RandomState(42)
    xq = rng.randn(batch_size, num_heads, head_dim).astype(np.float32)
    k_pages = rng.randn(num_kv_heads, total_num_pages, page_size, head_dim).astype(np.float32)
    v_pages = rng.randn(num_kv_heads, total_num_pages, page_size, head_dim).astype(np.float32)
    # Number of valid tokens in each sequence (1..8)
    lengths = rng.randint(1, 9, batch_size).astype(np.int32)
    # Which pages of k_pages/v_pages belong to each sequence
    page_indices = rng.randint(0, total_num_pages, (batch_size, pages_per_sequence)).astype(np.int32)
    xq, k_pages, v_pages, lengths, page_indices = jax.device_put((xq, k_pages, v_pages, lengths, page_indices))
    output = paged_attention(xq, k_pages, v_pages, lengths, page_indices, pages_per_compute_block=pages_per_compute_block)
    print("Output shape:", output.shape)
    print("Output:", output)

Error log:

F0922 08:01:26.447775 2321927 ba16c7433_dma_descriptor_state.cc:189] Check failed: data_size.Unit() == b.target().DmaLengthGranule() (k512Byte vs. k1024Byte) 
*** Check failure stack trace: ***
    @     0x7f29e11fb184  (unknown)
    @     0x7f29e11facc8  (unknown)
    @     0x7f29e1426009  (unknown)
    @     0x7f29da77ec08  (unknown)
    @     0x7f29da7f120f  (unknown)
    @     0x7f29da7f019a  (unknown)
    @     0x7f29da7ef4d4  (unknown)
    @     0x7f29d721f6f4  (unknown)
    @     0x7f29d721e7d5  (unknown)
    @     0x7f29d721b599  (unknown)
    @     0x7f29d721dc52  (unknown)
    @     0x7f29d721b599  (unknown)
    @     0x7f29d721dc52  (unknown)
    @     0x7f29d721d5aa  (unknown)
    @     0x7f29d721b599  (unknown)
    @     0x7f29d71c52a3  (unknown)
    @     0x7f29d71c465b  (unknown)
    @     0x7f29d9a3e036  (unknown)
    @     0x7f29d9a3ab60  (unknown)
    @     0x7f29d9a37ae1  (unknown)
    @     0x7f29d9a33519  (unknown)
    @     0x7f29d71b4bed  (unknown)
    @     0x7f29d71b0268  (unknown)
    @     0x7f29d71a36d3  (unknown)
    @     0x7f29d718a258  (unknown)
    @     0x7f29d71a3a79  (unknown)
    @     0x7f29d71a8315  (unknown)
    @     0x7f29d71ab627  (unknown)
    @     0x7f29e0e0fa5e  (unknown)
    @     0x7f29e0e15d96  (unknown)
    @     0x7f29e0e1e9a5  (unknown)
    @     0x7f29e10c4db3  (unknown)
    @     0x7f2a92555609  start_thread
https://symbolize.stripped_domain/r/?trace=7f29e11fb184,7f29e11facc7,7f29e1426008,7f29da77ec07,7f29da7f120e,7f29da7f0199,7f29da7ef4d3,7f29d721f6f3,7f29d721e7d4,7f29d721b598,7f29d721dc51,7f29d721b598,7f29d721dc51,7f29d721d5a9,7f29d721b598,7f29d71c52a2,7f29d71c465a,7f29d9a3e035,7f29d9a3ab5f,7f29d9a37ae0,7f29d9a33518,7f29d71b4bec,7f29d71b0267,7f29d71a36d2,7f29d718a257,7f29d71a3a78,7f29d71a8314,7f29d71ab626,7f29e0e0fa5d,7f29e0e15d95,7f29e0e1e9a4,7f29e10c4db2,7f2a92555608&map= 
https://symbolize.stripped_domain/r/?trace=7f2a925b300b,7f2a925b308f,7f29e11fb1e8,7f29e11facc7,7f29e1426008,7f29da77ec07,7f29da7f120e,7f29da7f0199,7f29da7ef4d3,7f29d721f6f3,7f29d721e7d4,7f29d721b598,7f29d721dc51,7f29d721b598,7f29d721dc51,7f29d721d5a9,7f29d721b598,7f29d71c52a2,7f29d71c465a,7f29d9a3e035,7f29d9a3ab5f,7f29d9a37ae0,7f29d9a33518,7f29d71b4bec,7f29d71b0267,7f29d71a36d2,7f29d718a257,7f29d71a3a78,7f29d71a8314,7f29d71ab626,7f29e0e0fa5d,7f29e0e15d95,7f29e0e1e9a4&map= 
*** SIGABRT received by PID 2321018 (TID 2321927) on cpu 59 from PID 2321018; ***
E0922 08:01:26.484000 2321927 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0922 08:01:26.484016 2321927 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0922 08:01:26.484021 2321927 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0922 08:01:26.484025 2321927 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0922 08:01:26.484044 2321927 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0922 08:01:26.484048 2321927 coredump_hook.cc:472] RAW: Dumping core locally.
F0922 08:01:26.447775 2321927 ba16c7433_dma_descriptor_state.cc:189] Check failed: data_size.Unit() == b.target().DmaLengthGranule() (k512Byte vs. k1024Byte) 
E0922 08:01:26.738090 2321927 process_state.cc:805] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

Note: this works fine on TPU-v4.
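Because the kernel does run on TPU v4, one way to sanity-check its output there is to compare it against a dense reference written with plain `jax.numpy`. The same reference should also run on TPU v3, since it uses only ordinary XLA ops, although it will be slower. This is a minimal sketch, not JAX's implementation. `reference_paged_attention` is a hypothetical helper, and it assumes the layouts used in the reproducer above: `q` is `[batch, num_heads, head_dim]` and `k_pages`/`v_pages` are `[num_kv_heads, total_num_pages, page_size, head_dim]`. It also assumes that query head `h` reads KV head `h // (num_heads // num_kv_heads)`, and that logits are unscaled with no soft-capping. Check these assumptions against the kernel before trusting a comparison.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Large finite negative value for masked logits (avoids NaNs from -inf when a
# row is fully masked). Chosen for this sketch, not taken from the kernel.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


@jax.jit
def reference_paged_attention(q, k_pages, v_pages, lengths, page_indices):
    """Dense (non-Pallas) paged attention for one decode step (hypothetical helper).

    q:            [batch, num_heads, head_dim]
    k_pages:      [num_kv_heads, total_num_pages, page_size, head_dim]
    v_pages:      same shape as k_pages
    lengths:      [batch], number of valid tokens per sequence
    page_indices: [batch, pages_per_sequence], page ids for each sequence
    """
    batch, num_heads, head_dim = q.shape
    num_kv_heads, _, page_size, _ = k_pages.shape
    group = num_heads // num_kv_heads  # query heads per KV head (assumed mapping)
    seq_len = page_indices.shape[1] * page_size

    # Gather each sequence's pages -> [num_kv_heads, batch, pages_per_seq, page_size, head_dim],
    # flatten pages into one token axis, then move batch to the front.
    k = k_pages[:, page_indices].reshape(num_kv_heads, batch, seq_len, head_dim)
    v = v_pages[:, page_indices].reshape(num_kv_heads, batch, seq_len, head_dim)
    k = k.transpose(1, 0, 2, 3).astype(jnp.float32)  # [batch, num_kv_heads, seq_len, head_dim]
    v = v.transpose(1, 0, 2, 3).astype(jnp.float32)

    # Query head h = kv * group + g attends with KV head kv.
    qg = q.reshape(batch, num_kv_heads, group, head_dim).astype(jnp.float32)
    logits = jnp.einsum("bhgd,bhtd->bhgt", qg, k)  # unscaled dot products (assumption)

    # Mask token positions at or beyond each sequence's length.
    valid = jnp.arange(seq_len)[None, :] < lengths[:, None]  # [batch, seq_len]
    logits = jnp.where(valid[:, None, None, :], logits, MASK_VALUE)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhgt,bhtd->bhgd", weights, v)
    return out.reshape(batch, num_heads, head_dim).astype(q.dtype)
```

On TPU v4, a check such as `np.testing.assert_allclose(output, reference_paged_attention(xq, k_pages, v_pages, lengths, page_indices), atol=1e-2)` could then be added to the reproducer. That tolerance is a guess, because TPU matmuls may use reduced precision by default.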

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.14 (main, Apr 6 2024, 18:45:05) [GCC 9.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='v3-8-node-1', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')

justinjfu commented 2 months ago

This kernel uses DMAs, which are unfortunately only supported on TPU v4 and newer.

We should definitely improve the error message here since it's quite opaque.
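Until the error message is improved, a caller can fail early with a readable message instead of letting the process abort. This is a minimal user-side sketch, not a JAX API. `tpu_generation`, `is_paged_attention_supported` and `require_paged_attention_support` are hypothetical helpers. Parsing the generation out of `Device.device_kind` assumes strings of the form "TPU v3", "TPU v4", "TPU v5 lite". The v4 threshold comes from the maintainer's reply above.

```python
import re

import jax

# From the maintainer's reply: the kernel's DMAs need TPU v4 or newer.
MIN_TPU_GENERATION = 4


def tpu_generation(device):
    """Parse the TPU generation from `device.device_kind` (hypothetical helper).

    Assumes kinds look like "TPU v3", "TPU v4", "TPU v5 lite", ...; returns
    None for non-TPU devices or kinds that don't match that pattern.
    """
    if device.platform != "tpu":
        return None
    match = re.search(r"TPU v(\d+)", device.device_kind)
    return int(match.group(1)) if match else None


def is_paged_attention_supported(device):
    """True if `device` looks like a TPU new enough for the DMA-based kernel."""
    gen = tpu_generation(device)
    return gen is not None and gen >= MIN_TPU_GENERATION


def require_paged_attention_support(device=None):
    """Raise a readable error instead of hitting the kernel's fatal check."""
    device = jax.devices()[0] if device is None else device
    if not is_paged_attention_supported(device):
        raise RuntimeError(
            f"paged_attention needs TPU v{MIN_TPU_GENERATION} or newer (DMA support); "
            f"found {device.device_kind!r} on platform {device.platform!r}."
        )
```

In the reproducer, calling `require_paged_attention_support()` just before `paged_attention(...)` should raise a `RuntimeError` on the v3-8 host instead of aborting, assuming its `device_kind` string matches the pattern above.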