Closed milutter closed 6 months ago
Hi @milutter
I executed the mentioned code with jax-metal 0.0.6 on a Macbook Pro with an M1 Pro chip to see if the reported issue persists. The code produces the same output regardless of using Just-In-Time (JIT) compilation.
Could you please verify with jax-metal 0.0.6 and confirm whether the issue still persists?
Thank you.
Thanks for following up @rajasekharporeddy, I'm going to close the issue. I can confirm that after upgrading `jax-metal`, `jax`, as well as `flax`, this error seems to be gone.
Description
Evaluating a simple multi-layer perceptron (MLP) implemented in `flax` on the same input data and parameters potentially yields non-deterministic outputs on the Apple Metal device when the function is NOT jitted. When the function is jitted, the outputs of the MLP are deterministic and the problem disappears. I was able to verify that this problem is specific to Apple Metal, as on a Linux system with an NVIDIA GPU the problem does not occur (with the current jax version).

Empirically, the problem occurs more frequently when the batch dimension is not of shape `2**n` with `n > 10`. For example, for batch dimensions of 2500 and 5000 the problem occurs frequently. Another empirical observation is that the values are not random but repeat themselves. For example, `y[0, 0]` is always one of `m` different numbers (empirically `m ≈ 3-4`), but it is random which of the `m` options one gets, which hints at a memory problem.

It is debatable whether this is a `jax`, `flax`, or `apple metal plugin` issue. I am happy to file this issue at a different location if preferred.

Output WITHOUT JIT => Non-Deterministic Output:
Output WITH JIT => Everything works as expected
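A minimal determinism check along these lines can reproduce the comparison described above. This is a sketch, not the exact code from the report: the layer sizes, seeds, and the batch size of 2500 are illustrative assumptions, and a plain JAX MLP is used in place of the original `flax` module since the forward pass is equivalent for this purpose.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    # Initialize weights and biases for a simple MLP (illustrative sizes).
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params

def mlp(params, x):
    # tanh hidden layers, linear output layer.
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

params = init_mlp(jax.random.PRNGKey(0), [3, 64, 64, 1])
# Non-power-of-two batch dimension, where the issue was reported to occur often.
x = jax.random.normal(jax.random.PRNGKey(1), (2500, 3))

# Evaluate the same function twice on identical inputs and parameters.
y0 = mlp(params, x)            # eager (no JIT)
y1 = mlp(params, x)
y_jit = jax.jit(mlp)(params, x)

# On the affected jax-metal versions, y0 and y1 could differ in the eager
# case; with JIT (and on CPU/CUDA, or after the fix) both checks pass.
print(bool(jnp.array_equal(y0, y1)))
print(bool(jnp.allclose(y0, y_jit, atol=1e-5)))
```

Running this twice per configuration and comparing `y[0, 0]` across runs is enough to observe the small set of `m` repeating values described above.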
What jax/jaxlib version are you using?
0.4.11, 0.4.11
Which accelerator(s) are you using?
Apple Metal
Additional system info?
Python 3.10, NumPy 1.26.0, Platform macOS / Darwin
NVIDIA GPU info
No response