jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

psum is nondeterministic on multiple CPUs #11366

Open mar-muel opened 2 years ago

mar-muel commented 2 years ago

Hello

I'm experiencing occasional non-deterministic behaviour when running the script below on a multi-device CPU setup (using the flag --xla_force_host_platform_device_count=8).
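For context, this flag splits the single host into several virtual CPU devices. A minimal sketch of the setup (the expected device list is shown as a comment, not captured output):

import os
# XLA_FLAGS must be set before the JAX backend initializes,
# i.e. before the first call into jax.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax

print(jax.local_devices())  # expect [CpuDevice(id=0), ..., CpuDevice(id=7)]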

The script runs 10 attempts, each of which 1) loads the dataset, 2) initializes the network, 3) pmaps the train_step function, and 4) runs a single batch through the network.

Given that we always use the same network initialization and apply no shuffling/permutation to the dataset, I would expect the loss of this first batch to be identical across attempts. However, with some probability (<10%) I get an outlier (second-to-last value below):

Losses:
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901922225952   <---
2.316901683807373

The problem seems to occur only on CPU. I've tried the same on TPUs/TPU pods, and there everything was fully deterministic. The outlier value seems to be different every time.
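For what it's worth, the two observed loss values are adjacent float32 numbers, exactly one ULP apart, which is consistent with a single change in floating-point summation order:

import numpy as np
a = np.float32(2.316901683807373)
b = np.float32(2.316901922225952)
print(b - a)          # 2.3841858e-07
print(np.spacing(a))  # 2.3841858e-07, i.e. one float32 ULP at this magnitude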

I was curious whether someone here has an idea what the source of randomness could be.

Here's the script to reproduce the above:


from flax import linen as nn
from flax.training import train_state
from flax.training.common_utils import shard
import flax
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
import os
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s')
logger = logging.getLogger(__name__)

# XLA_FLAGS is read when the XLA backend first initializes (lazily, on the
# first jax call), so setting it here, after import, still takes effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

def loss_fn(params, images, labels):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params, images, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy, logits

@jax.jit
def train_step(state, images, labels):
    grads, loss, accuracy, logits = apply_model(state, images, labels)
    grads = jax.lax.pmean(grads, axis_name='device')
    metrics = jax.lax.pmean({'accuracy': accuracy, 'loss': loss}, axis_name='device')
    new_state = state.apply_gradients(grads=grads)
    return new_state, metrics, logits

def train_data_loader(rng, dataset, batch_size, shuffle=True):
    num_samples = len(dataset['image'])
    steps_per_epoch = num_samples // batch_size

    if shuffle:
        # generate a random permutation of the integers [0, num_samples)
        logger.info('Shuffling dataset...')
        perms = jax.random.permutation(rng, num_samples)
    else:
        logger.info('Not shuffling dataset...')
        perms = jnp.arange(num_samples)

    # skip last/incomplete batch
    perms = perms[:steps_per_epoch * batch_size]

    # reshape into steps x batch size
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        # resulting data shape: (batch_size, 28, 28, 1)
        batch = {k: dataset[k][perm] for k in dataset.keys()}
        # shard batch to all devices
        # resulting data shape: (num_devices, batch_size/num_devices, 28, 28, 1)
        batch = shard(batch)
        yield batch

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist', data_dir='tensorflow_datasets')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

def create_train_state(rng, lr=.1, momentum=.9):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(lr, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

def train_and_evaluate(batch_size=256):
    losses = []
    for attempt in range(10):
        train_ds, test_ds = get_datasets()
        rng = jax.random.PRNGKey(0)

        rng, init_rng = jax.random.split(rng)
        state = create_train_state(init_rng)

        # replicate state across all devices
        state = flax.jax_utils.replicate(state)

        rng, input_rng = jax.random.split(rng)  # input_rng is unused below since shuffle=False
        p_train_step = jax.pmap(train_step, in_axes=(0, 0, 0), axis_name='device')

        for batch in train_data_loader(rng, train_ds, batch_size, shuffle=False):
            state, metrics, logits = p_train_step(state, batch['image'], batch['label'])
            losses.append(metrics['loss'][0].item())
            break

    print('Losses:')
    for loss in losses:
        print(loss)

def main():
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    train_and_evaluate()

if __name__ == '__main__':
    main()

Logs when running the script above:

2022-07-05 12:14:44,840 [INFO ] [absl        ]: Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
2022-07-05 12:14:44,841 [INFO ] [absl        ]: Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
2022-07-05 12:14:44,841 [INFO ] [absl        ]: Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2022-07-05 12:14:44,841 [WARNI] [absl        ]: No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2022-07-05 12:14:44,841 [INFO ] [root        ]: JAX process: 0 / 1
2022-07-05 12:14:44,842 [INFO ] [root        ]: JAX local devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
2022-07-05 12:14:44,842 [INFO ] [absl        ]: Load dataset info from tensorflow_datasets/mnist/3.0.1
2022-07-05 12:14:44,845 [INFO ] [absl        ]: Field info.citation from disk and from code do not match. Keeping the one from code.
2022-07-05 12:14:44,845 [INFO ] [absl        ]: Field info.splits from disk and from code do not match. Keeping the one from code.
2022-07-05 12:14:44,845 [INFO ] [absl        ]: Field info.supervised_keys from disk and from code do not match. Keeping the one from code.
2022-07-05 12:14:44,845 [INFO ] [absl        ]: Field info.module_name from disk and from code do not match. Keeping the one from code.
2022-07-05 12:14:44,846 [INFO ] [absl        ]: Reusing dataset mnist (tensorflow_datasets/mnist/3.0.1)
2022-07-05 12:14:44,846 [INFO ] [absl        ]: Constructing tf.data.Dataset mnist for split train, from tensorflow_datasets/mnist/3.0.1
2022-07-05 12:14:46,163 [INFO ] [absl        ]: Constructing tf.data.Dataset mnist for split test, from tensorflow_datasets/mnist/3.0.1
2022-07-05 12:14:48,277 [INFO ] [__main__    ]: Not shuffling dataset...
/Users/martin/miniconda3/envs/jax/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
[... the same "Load dataset info / Constructing tf.data.Dataset / Not shuffling dataset" block repeats for each of the remaining nine attempts ...]
Losses:
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901683807373
2.316901922225952
2.316901683807373

Here are my system specs:

$ python --version
Python 3.8.0
$ pip freeze | grep -E 'jax|flax'
flax==0.4.1
jax==0.3.14
jaxlib==0.3.10
$ uname -a
Darwin Martins-MacBook-Pro.local 20.6.0 Darwin Kernel Version 20.6.0: Tue Apr 19 21:04:45 PDT 2022; root:xnu-7195.141.29~1/RELEASE_X86_64 x86_64
hawkinsp commented 2 years ago

I'm guessing, but have not verified, that the order of the psum reduction is nondeterministic on CPU and depends on the thread schedule.
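Floating-point addition is not associative, so two schedules that combine the same partial results in a different order can disagree in the last bits. A tiny illustration in plain Python:

print((0.1 + 0.2) + 0.3)  # 0.6000000000000001
print(0.1 + (0.2 + 0.3))  # 0.6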

hawkinsp commented 2 years ago

Here's a simpler reproduction:

import os
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

@partial(jax.pmap, axis_name="i")
def f(x):
  return lax.psum(x, axis_name=("i",))

np.random.seed(1234)
x = np.random.randn(8)

print(f(x))

Output:

$ JAX_PLATFORMS=cpu python t.py
[0.7901536 0.7901536 0.7901536 0.7901536 0.7901536 0.7901536 0.7901536
 0.7901536]
$ JAX_PLATFORMS=cpu python t.py
[0.79015374 0.79015374 0.79015374 0.79015374 0.79015374 0.79015374
 0.79015374 0.79015374]
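One quick way to probe this (a sketch reusing f and x from the reproduction above) is to run the same reduction many times in one process and collect the distinct results; whether the order varies within a single process depends on the thread schedule, so several fresh runs may be needed:

# Collect the distinct psum results over many calls; more than one distinct
# value means the reduction order changed between calls.
results = {float(f(x)[0]) for _ in range(1000)}
print(results)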