nathom opened 1 month ago
Pinging @sharadmv who will know best.
I think there is an out-of-bounds bug in the kernel that you wrote, which is hitting a runtime bounds check.
Specifically, the block spec for `b`:

```python
pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (1, l, 1, j)),
```

should be:

```python
pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (0, l, 0, j)),
```
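Why the first spec goes out of bounds: with the default blocked indexing, the `index_map` returns *block* indices, so the element offset along each axis is `block_index * block_size`. Returning `1` on an axis whose block size is `1` starts the block at element 1, which is past the end if `b` has size 1 on that axis. The shape of `b` is not given in the issue, so the sketch below assumes `b` has shape `(1, L, 1, N)`; the helper `find_oob_blocks`, the grid extents, and the sizes are all hypothetical. It only flags blocks that start past the end of an axis (partial edge blocks are not flagged).

```python
import itertools
from typing import Callable, Sequence


def find_oob_blocks(
    array_shape: Sequence[int],
    block_shape: Sequence[int],
    index_map: Callable[..., tuple],
    grid: Sequence[int],
) -> list:
    """Hypothetical checker: list (grid_point, dim, start) for blocks starting past an axis end.

    Assumes blocked indexing, where index_map returns block indices and the
    element offset along a dim is block_index * block_size.
    """
    bad = []
    for point in itertools.product(*(range(g) for g in grid)):
        block_idx = index_map(*point)
        for dim, (bi, bs, n) in enumerate(zip(block_idx, block_shape, array_shape)):
            start = bi * bs  # first element this block would read on this dim
            if start < 0 or start >= n:
                bad.append((point, dim, start))
    return bad


# Hypothetical sizes: b has shape (1, L, 1, N) with L = 16, N = 256.
bl, bn = 8, 128
b_shape = (1, 16, 1, 256)
b_block = (1, bl, 1, bn)
grid = (1, 2, 1, 2, 1)  # (nc, l, i, j, k)

buggy = lambda nc, l, i, j, k: (1, l, 1, j)  # original index map
fixed = lambda nc, l, i, j, k: (0, l, 0, j)  # suggested fix

buggy_bad = find_oob_blocks(b_shape, b_block, buggy, grid)
fixed_bad = find_oob_blocks(b_shape, b_block, fixed, grid)
print(len(buggy_bad), len(fixed_bad))
```

Under these assumed sizes, every grid point of the original map reads past the end of axes 0 and 2, while the fixed map stays in bounds.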
I think we could catch this error in interpret mode if we use checkify to look for OOB indexing.
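For reference, this is how `checkify` reports out-of-bounds indexing on ordinary JAX code. It is only an illustration of the mechanism: it does not run a Pallas kernel, and whether interpret mode would surface the `BlockSpec` bug this way is exactly what the comment above proposes, not something shown here. The function `read_row` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def read_row(x, i):
    # Dynamic integer index: plain JAX indexing clamps OOB reads instead of failing.
    return x[i]


# Instrument the function so out-of-bounds indexing is returned as an error value.
checked_read_row = jax.jit(checkify.checkify(read_row, errors=checkify.index_checks))

x = jnp.zeros((1, 4))  # leading axis of size 1, like the singleton axes of b
err_ok, _ = checked_read_row(x, 0)   # in bounds
err_oob, _ = checked_read_row(x, 1)  # index 1 on an axis of size 1

print(err_ok.get())   # no error recorded
print(err_oob.get())  # message describing the out-of-bounds index
```

Calling `err_oob.throw()` instead of `get()` would raise the error eagerly.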
Description
Hello, I'm running into a core dump when writing TPU kernels. When I tested with `interpret=True`, the kernel worked. Without it, I get a core dump. Any temporary fix is appreciated!
I ran with
When `interpret=True`, the assertion passes. When `interpret=False`, I get a core dump:

System info (python version, jaxlib version, accelerator, etc.)