OhadRubin opened 2 months ago
code to reproduce:
```python
import jax
import numpy as np
# Import path as of jax 0.4.33.
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

if __name__ == '__main__':
    # Set up example inputs
    batch_size = 8
    num_pages = 32
    page_size = 16
    head_dim = 128
    total_num_pages = 64 * 8
    num_kv_heads = num_heads = 4
    pages_per_compute_block = 8
    pages_per_sequence = 8

    rng = np.random.RandomState(42)
    xq = rng.randn(batch_size, num_heads, head_dim).astype(np.float32)
    k_pages = rng.randn(num_kv_heads, total_num_pages, page_size, head_dim).astype(np.float32)
    v_pages = rng.randn(num_kv_heads, total_num_pages, page_size, head_dim).astype(np.float32)
    lengths = rng.randint(1, 9, batch_size).astype(np.int32)
    page_indices = rng.randint(0, total_num_pages, (batch_size, pages_per_sequence)).astype(np.int32)

    xq, k_pages, v_pages, lengths, page_indices = jax.device_put(
        (xq, k_pages, v_pages, lengths, page_indices))

    output = paged_attention(
        xq, k_pages, v_pages, lengths, page_indices,
        pages_per_compute_block=pages_per_compute_block)
    print("Output shape:", output.shape)
    print("Output:", output)
```
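For context, here is a quick sanity check on the constants in the snippet above. The interpretation that `lengths` counts tokens and must fit within `pages_per_sequence * page_size` is an assumption based on how the repro is constructed, not a documented contract; the per-page byte count is shown only because the crash below happens inside a DMA length-granule check.

```python
# Same constants as the repro above.
batch_size = 8
page_size = 16
head_dim = 128
pages_per_sequence = 8
total_num_pages = 64 * 8

# Each sequence can address at most this many tokens through its page table.
max_tokens_per_sequence = pages_per_sequence * page_size
print(max_tokens_per_sequence)  # 128

# The repro draws lengths from [1, 9), so every sequence fits comfortably.
assert 8 <= max_tokens_per_sequence

# Bytes per page of float32 keys (or values): page_size * head_dim * 4.
bytes_per_page = page_size * head_dim * 4
print(bytes_per_page)  # 8192
```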
error log:
```
F0922 08:01:26.447775 2321927 ba16c7433_dma_descriptor_state.cc:189] Check failed: data_size.Unit() == b.target().DmaLengthGranule() (k512Byte vs. k1024Byte)
*** Check failure stack trace: ***
    @ 0x7f29e11fb184  (unknown)
    @ 0x7f29e11facc8  (unknown)
    ...
    @ 0x7f29e10c4db3  (unknown)
    @ 0x7f2a92555609  start_thread
*** SIGABRT received by PID 2321018 (TID 2321927) on cpu 59 from PID 2321018; ***
E0922 08:01:26.484000 2321927 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0922 08:01:26.484016 2321927 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0922 08:01:26.484021 2321927 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0922 08:01:26.484025 2321927 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0922 08:01:26.484044 2321927 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0922 08:01:26.484048 2321927 coredump_hook.cc:472] RAW: Dumping core locally.
E0922 08:01:26.738090 2321927 process_state.cc:805] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
```
Note: this works fine on TPU-v4.
```
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.10.14 (main, Apr 6 2024, 18:45:05) [GCC 9.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='v3-8-node-1', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')
```
This kernel uses DMAs, which are unfortunately only supported on TPU v4 and newer.
We should definitely improve the error message here, since it's quite opaque.