Open finbarrtimbers opened 1 month ago
My guess is that it's the mask operand that is causing the problem. If I understand correctly, it's an array of jnp.bool, but this type should not be supported at the kernel boundary. I think the best workaround would be to cast the mask to int8/int32 and then do mask != 0 inside the kernel to recover it in the boolean form.
Ultimately there are two things to fix here: (1) make Pallas more picky about input operand types and (2) add support for passing booleans to kernels.
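A minimal sketch of that workaround, assuming a toy kernel that zeroes out the masked-off elements (the names masked_copy_kernel and masked_copy are made up for illustration and are not from the repro). The mask is cast to int32 before the call and turned back into a boolean inside the kernel. The example passes interpret=True only so that it also runs without a TPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def masked_copy_kernel(x_ref, mask_ref, o_ref):
    # The mask arrives as int32; recover the boolean form inside the kernel.
    mask = mask_ref[...] != 0
    x = x_ref[...]
    o_ref[...] = jnp.where(mask, x, jnp.zeros_like(x))


def masked_copy(x, bool_mask, interpret=False):
    # Hypothetical wrapper: cast the boolean mask at the kernel boundary.
    int_mask = bool_mask.astype(jnp.int32)
    return pl.pallas_call(
        masked_copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, int_mask)


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
bool_mask = (jnp.arange(8 * 128).reshape(8, 128) % 2) == 0
out = masked_copy(x, bool_mask, interpret=True)
```

On a TPU the same wrapper would be called with interpret=False; whether int8 or int32 is the better choice for the cast likely depends on the kernel's tiling, which this sketch does not address.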
Ah, yes, that fixed it. Thanks!
Is it possible to have this be checked on CPU as well? I find it really confusing when there's such different behaviour between CPU and TPU.
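Until Pallas itself rejects such operands, a user-side guard can give the same early error on every backend. The helper below is hypothetical (it is not part of Pallas) and simply refuses boolean arrays before they reach pallas_call:

```python
import jax.numpy as jnp


def assert_no_bool_operands(*operands):
    # Hypothetical helper, not a Pallas API: fail fast on boolean operands
    # so CPU and TPU runs report the same problem.
    for i, operand in enumerate(operands):
        dtype = jnp.asarray(operand).dtype
        if jnp.issubdtype(dtype, jnp.bool_):
            raise TypeError(
                f"operand {i} has dtype {dtype}; cast it to int8/int32 "
                "before passing it to the kernel"
            )


# Passes silently: no boolean operands.
assert_no_bool_operands(
    jnp.ones((2, 2)), jnp.ones((2, 2), dtype=jnp.int32)
)
```

Calling it on the kernel's inputs right before pallas_call would surface the issue during CPU testing as well.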
I think on CPU we only simulate the custom call, which you can also get on a TPU if you pass in interpret=True. But note that then you won't actually be generating a kernel; it will instead expand to a soup of loopy HLOs. Either way, we should fix bool support.
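For reference, a small sketch of what passing interpret=True looks like, using a made-up elementwise add kernel rather than the kernel from this issue. With the flag set, the call is simulated instead of being compiled to a kernel, so it runs on CPU too:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Illustrative kernel only: elementwise sum of the two inputs.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y, interpret):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,  # True: simulate the call, no kernel is generated
    )(x, y)


x = jnp.ones((8, 128), dtype=jnp.float32)
y = jnp.full((8, 128), 2.0, dtype=jnp.float32)
result = add(x, y, interpret=True)
```

Because the interpreted path does not go through the TPU compiler, it will not necessarily reproduce compiler-side failures such as the one reported here.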
Ah, ty, I misunderstood and thought that the error was still happening with interpret=True; you're right, it doesn't.
Description
I have a pallas kernel that's dying with a cryptic CHECK failure:
I don't have a useful stack trace here; it's just the standard Google3 CHECK failure output, which is pretty useless (see ).
It happens in my call to pallas_call on a v5p-8. It traces my kernel fine but then dies before returning from pallas_call, which I assume means this is some sort of compilation error? It happens when interpret is set to both true and false.
The code runs without issue on CPU.
The error does not come up when I remove the custom_vjp decorator and only run the forward pass. I have a repro on gist.
System info (python version, jaxlib version, accelerator, etc.)
System info for CPU env (where the code succeeds):
System info for TPU env (where the code fails):