Open mar-muel opened 2 years ago
I'm guessing, but have not verified, that the order of the psum
reduction is nondeterministic on CPU and depends on the thread schedule.
Here's a simpler reproduction:
```python
import os

# Must be set before importing jax, or the flag has no effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

from functools import partial

import jax
import jax.lax as lax
import numpy as np


@partial(jax.pmap, axis_name="i")
def f(x):
    return lax.psum(x, axis_name=("i",))


np.random.seed(1234)
x = np.random.randn(8)
print(f(x))
```
Output:

```
$ JAX_PLATFORMS=cpu python t.py
[0.7901536 0.7901536 0.7901536 0.7901536 0.7901536 0.7901536 0.7901536
 0.7901536]
$ JAX_PLATFORMS=cpu python t.py
[0.79015374 0.79015374 0.79015374 0.79015374 0.79015374 0.79015374
 0.79015374 0.79015374]
```
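The two outputs differ only in the last bits, which is consistent with the reduction being performed in a different order on each run: floating-point addition is not associative, so a different summation order can round to a different result. A minimal demonstration in plain Python (independent of JAX):

```python
# Floating-point addition is not associative: the same three values
# summed with different groupings round to different results.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # 0.6000000000000001
right = a + (b + c)  # 0.6

print(left == right)  # False: the grouping changes the rounding
```

The same effect applies to a psum over 8 devices: if the runtime combines the per-device partial sums in a schedule-dependent order, the last bits of the result can vary from run to run.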
Hello
I'm experiencing occasional non-deterministic behaviour when running the script below on a multi-device CPU (using the flag `--xla_force_host_platform_device_count=8`). The script runs 10 attempts, each of which 1) loads the dataset, 2) initializes the network, 3) pmaps the `train_step` function, and 4) runs a single batch through the network.
Given that we always use the same network initialization and no shuffling/permutation of the dataset, I would expect the loss on this first batch to be identical across attempts. However, with some probability (<10%) I get an outlier (the second-to-last value):
The problem seems to occur only on CPU. I've tried the same on TPU/TPU pods, where everything was fully deterministic. The outlier value is different every time.
I was curious whether someone here has an idea of what the source of randomness could be.
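For concreteness, the attempt structure described above might look like the following sketch. The dataset, `init_params`, and `train_step` here are stand-ins I made up to illustrate the loop, not the actual script:

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_params(rng):
    # Stand-in for the real network initialization; fixed seed -> same params.
    return {"w": jax.random.normal(rng, (4, 4))}


def train_step(params, batch):
    # Stand-in loss: mean squared activation, averaged across devices.
    loss = jnp.mean((batch @ params["w"]) ** 2)
    return jax.lax.pmean(loss, axis_name="i")


n_dev = jax.local_device_count()
p_step = jax.pmap(train_step, axis_name="i")

losses = []
for attempt in range(10):
    rng = jax.random.PRNGKey(0)  # same initialization every attempt
    params = jax.device_put_replicated(init_params(rng), jax.local_devices())
    batch = np.ones((n_dev, 2, 4), dtype=np.float32)  # same first batch
    losses.append(float(p_step(params, batch)[0]))

print(losses)  # ideally all identical; an outlier indicates nondeterminism
```

With identical seeds and identical data, any spread in `losses` across attempts (or across separate process runs) points at nondeterminism in the runtime rather than in the model.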
Here is the script to reproduce the above:
Logs when running the script above:
Here are my system specs: