Description
I ran into a surprising case involving scatter updates on 3D boolean arrays that only seems to fail on jax-metal.
Correct behavior (CPU):
But on jax-metal, I get:
After playing around with some inputs, the problem seems to occur for `.at[:, i]` and `.at[:, :, i]`, but `.at[i]` works fine. So any boolean scatter update along a dimension other than 0 appears to hit a bug.

Is there some way I can help debug this? Is the jax-metal code open source? If it is, and you point me to build instructions, I can try to track down the bug.
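A minimal self-check sketch for the pattern above. This is not the original reproduction. The shape `(2, 3, 4)`, index `i = 1`, random initial values and the helper names `check_bool_scatter` and `set_bool_via_uint8` are all hypothetical. The sketch compares each boolean `.at[...].set(True)` update against a NumPy reference. On CPU every case should come out `True`, and going by the report, the `.at[:, i]` and `.at[:, :, i]` cases would be expected to fail on jax-metal. The `uint8` round-trip is only a possible workaround, assuming the bug is specific to boolean scatters. It has not been verified on jax-metal.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical test shape and index; not taken from the original report.
SHAPE = (2, 3, 4)
I = 1

# Index tuples equivalent to .at[i], .at[:, i] and .at[:, :, i].
CASES = {
    "at[i]": (I,),
    "at[:, i]": (slice(None), I),
    "at[:, :, i]": (slice(None), slice(None), I),
}


def set_bool_via_uint8(x, idx, value):
    """Possible workaround (assumption, untested on jax-metal):
    do the scatter on a uint8 copy and cast the result back to bool."""
    return x.astype(jnp.uint8).at[idx].set(jnp.uint8(value)).astype(jnp.bool_)


def check_bool_scatter(shape=SHAPE, seed=0):
    """Return {case: (direct_ok, workaround_ok)} versus a NumPy reference."""
    # Random boolean input so both True and False entries are exercised.
    x_np = np.random.default_rng(seed).random(shape) > 0.5
    x = jnp.asarray(x_np)
    results = {}
    for name, idx in CASES.items():
        expected = x_np.copy()
        expected[idx] = True
        direct = np.asarray(x.at[idx].set(True))
        workaround = np.asarray(set_bool_via_uint8(x, idx, True))
        results[name] = (
            bool(np.array_equal(direct, expected)),
            bool(np.array_equal(workaround, expected)),
        )
    return results


if __name__ == "__main__":
    for name, (direct_ok, workaround_ok) in check_bool_scatter().items():
        print(f"{name:12s} direct={direct_ok} uint8_workaround={workaround_ok}")
```

Comparing against NumPy instead of a second JAX backend keeps the reference independent of any device plugin, so the same script can be run unchanged on both machines.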
System info (python version, jaxlib version, accelerator, etc.)
On jax-metal:
On CPU (correct behavior):
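The system info for both environments can be gathered the same way on each machine. This is a sketch that assumes a JAX version recent enough to provide `jax.print_environment_info()`:

```python
import jax

# Print jax, jaxlib, numpy and python versions plus the visible devices.
jax.print_environment_info()

# Show which backend the default device comes from and list the devices.
print(jax.default_backend(), jax.devices())
```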