google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cryptic error message when running pallas kernel on TPU #21488

Open finbarrtimbers opened 1 month ago

finbarrtimbers commented 1 month ago

Description

I have a pallas kernel that's dying with a cryptic CHECK failure:

Check failed: 4 <= bitwidth (4 vs. signed char value 1)

I don't have a useful stack trace here; it's just the standard Google3 CHECK failure output, which isn't informative (see screenshot).

It happens in my call to pallas_call on a v5p-8. It traces my kernel fine but then dies before returning from pallas_call, which I assume means this is some sort of compilation error? It happens whether interpret is set to True or False.

The code runs without issue on CPU.

The error does not come up when I remove the custom_vjp decorator and only run the forward pass.

I have a repro on gist.

System info (python version, jaxlib version, accelerator, etc.)

System info for CPU env (where the code succeeds):

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.11 (main, May 23 2023, 13:58:30) [GCC 10.2.1 20210110]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='62b391e8e285', release='5.10.0-29-cloud-amd64', version='#1 SMP Debian 5.10.216-1 (2024-05-03)', machine='x86_64')

System info for TPU env (where the code fails):

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.23.5
python: 3.10.11 (main, May 23 2023, 13:58:30) [GCC 10.2.1 20210110]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='gke-tpu-306bac09-cmh9', release='6.1.75+', version='#1 SMP PREEMPT_DYNAMIC Sat Mar 30 14:38:17 UTC 2024', machine='x86_64')
apaszke commented 1 month ago

My guess is that it's the mask operand that is causing the problem. If I understand correctly, it's an array of jnp.bool, but this type should not be supported at the kernel boundary. I think the best workaround would be to cast the mask to int8/int32 and then do mask != 0 inside the kernel to recover it in the boolean form.

Ultimately there are two things to fix here: (1) make Pallas more picky about input operand types and (2) add support for passing booleans to kernels.
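The suggested workaround, sketched below on a hypothetical masked-add kernel (this is not the code from the gist; masked_add and its shapes are made up for illustration, and interpret=True is used so the sketch also runs on CPU):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def masked_add_kernel(x_ref, mask_ref, o_ref):
    # Recover the boolean mask from its int8 encoding inside the kernel.
    mask = mask_ref[...] != 0
    o_ref[...] = jnp.where(mask, x_ref[...] + 1.0, x_ref[...])

def masked_add(x, mask):
    # Cast the bool mask to int8 *before* it crosses the kernel boundary;
    # boolean operands at the boundary are what trigger the failure.
    return pl.pallas_call(
        masked_add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpret mode so this sketch runs without a TPU
    )(x, mask.astype(jnp.int8))

x = jnp.zeros((8, 128), jnp.float32)
mask = (jnp.arange(8 * 128).reshape(8, 128) % 2) == 0
out = masked_add(x, mask)
```

The same pattern applies to any bool operand: cast to an integer dtype at the call site, compare against zero inside the kernel.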

finbarrtimbers commented 1 month ago

Ah, yes, that fixed it. Thanks!

Is it possible to have this checked on CPU as well? I find it really confusing when there's such different behaviour between CPU and TPU.


apaszke commented 1 month ago

I think on CPU we only simulate the custom call, which is what you get on a TPU if you pass interpret=True. But note that in that mode you aren't actually generating a kernel; it expands to a soup of loopy HLOs instead. Either way, we should fix bool support.

finbarrtimbers commented 1 month ago

Ah, thanks! I misunderstood and thought the error was still happening with interpret=True; you're right, it isn't.
