NVIDIA / framework-reproducibility

Providing reproducibility in deep learning frameworks
Apache License 2.0

Usage of numpy.random.Generator to deal with Data-Loader Parallelism #36

Closed zakajd closed 3 years ago

zakajd commented 3 years ago

Hi @duncanriach, thanks for an awesome repo, which helped fix a number of issues in my code. For now I can achieve reproducible results running on GPU with TF 2.4.1 and a single worker, which is good, but painfully slow.

You mentioned that there is an option to define a pseudorandom number generator state for each parallel worker. Do you have a code example of how to do that with TF 2.0+? It seems that the spawning of workers is handled internally, and the only way to interact with the process is to set the num_parallel_calls value on the dataset.

Best regards, Jamil

duncanriach commented 3 years ago

Hi Jamil (@zakajd),

Thank you for the appreciation. Which method of tf.data.Dataset are you calling with the num_parallel_calls parameter?

Duncan

zakajd commented 3 years ago

Currently train and val datasets use the following methods:

train_dataset = (
    train_dataset
    .cache()
    .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    .shuffle(512)
    .map(
        lambda o, e: (resize_and_rescale(o), resize_and_rescale(e)),
        num_parallel_calls=AUTOTUNE
    )
    .batch(cfg.training.batch_size, drop_remainder=True)
    .map(
        lambda o, e: augment(o, e, training=True), num_parallel_calls=AUTOTUNE
    )
    .prefetch(AUTOTUNE)
)
val_dataset = (
    val_dataset
    .cache()
    .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    .map(
        lambda o, e: (resize_and_rescale(o), resize_and_rescale(e)),
        num_parallel_calls=AUTOTUNE
    )
    .batch(cfg.training.batch_size, drop_remainder=False)
    .map(lambda o, e: augment(o, e, training=False), num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

When setting the AUTOTUNE parameter to 1, results are reproducible between different runs on both GPU and CPU (though results between GPU and CPU differ slightly). When setting AUTOTUNE=4 or AUTOTUNE=tf.data.AUTOTUNE, each run results in a different model.

duncanriach commented 3 years ago

Hi @zakajd, I suspect that my original guidance on this may have come from data augmentation pipelines that did not use tf.data.Dataset. One way to solve this problem with tf.data.Dataset is to serialize the generation of pseudorandom parameters and then pass those into the computationally-expensive data augmentation process, which can then be arbitrarily parallelized. Here is some example code:

import tensorflow as tf
import numpy as np

np.set_printoptions(precision=2, floatmode='fixed')

def augment_random(x):
  random = np.float32(np.random.uniform())
  # The addition here represents a computationally-expensive set of operations
  return np.float32(x) + random

def random_param():
  return np.float32(np.random.uniform())

def augment(x, y):
  # The addition here represents a computationally-expensive set of operations
  return tf.add(tf.cast(x, tf.float32), y)

def nondeterministic_pipeline():
  np.random.seed(123)
  dataset = tf.data.Dataset.range(1, 6)
  dataset = dataset.map(
      # From tf.data.Dataset::map documentation: Note that use of
      # tf.numpy_function or tf.py_function in general precludes the possibility
      # of executing user-defined transformations in parallel (because of the
      # Python GIL).
      lambda x: tf.numpy_function(augment_random, inp=[x], Tout=tf.float32),
      num_parallel_calls=5)
  return np.array(list(dataset.as_numpy_iterator()))

def deterministic_pipeline():
  np.random.seed(123)
  dataset = tf.data.Dataset.range(1, 6)
  dataset = dataset.map(
      # From tf.data.Dataset::map documentation: Note that use of
      # tf.numpy_function or tf.py_function in general precludes the possibility
      # of executing user-defined transformations in parallel (because of the
      # Python GIL).
      lambda x: (x, tf.numpy_function(random_param, inp=[], Tout=tf.float32)),
      num_parallel_calls=1)
  dataset = dataset.map(lambda x, y: augment(x, y), num_parallel_calls=5)
  return np.array(list(dataset.as_numpy_iterator()))

result1 = nondeterministic_pipeline()
result2 = nondeterministic_pipeline()
result3 = deterministic_pipeline()
result4 = deterministic_pipeline()

print("\nGenerate random parameters in parallel:")
print("run 1: ",result1)
print("run 2: ",result2)

print("\nGenerate random parameters in series:")
print("run 1: ",result3)
print("run 2: ",result4)

# Generate random parameters in parallel:
# run 1:  [1.55 2.72 3.29 4.70 5.23]
# run 2:  [1.72 2.23 3.55 4.29 5.70]

# Generate random parameters in series:
# run 1:  [1.70 2.29 3.23 4.55 5.72]
# run 2:  [1.70 2.29 3.23 4.55 5.72]

duncanriach commented 3 years ago

I have also updated the documentation for deterministic data-loader parallelism to cover this topic.

duncanriach commented 3 years ago

I just discovered that the solution I suggested above was previously proposed in tensorflow/tensorflow issue 13932.

I'm now closing this issue.

zakajd commented 3 years ago

Thanks! I ended up doing something similar, but used tf.random.uniform and implemented stateless random operations similar to tf.image.stateless_random_contrast.
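
For illustration, here is a minimal sketch of what such a custom stateless op might look like (the helper name stateless_random_scale and the scaling transform are hypothetical, not the implementation used here), built on tf.random.stateless_uniform in the style of tf.image.stateless_random_contrast:

import tensorflow as tf

def stateless_random_scale(image, lower, upper, seed):
  # Hypothetical stateless op: the scale factor is a pure function of `seed`,
  # so the result does not depend on which worker runs it or in what order.
  scale = tf.random.stateless_uniform([], seed=seed, minval=lower, maxval=upper)
  return image * scale

# Same seed -> same augmentation on every run and on every worker.
image = tf.ones([4, 4, 3])
augmented = stateless_random_scale(image, lower=0.9, upper=1.1, seed=[42, 7])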

duncanriach commented 3 years ago

The relatively new stateless random image ops, such as tf.image.stateless_sample_distorted_bounding_box, can also be used with this approach: a per-example seed is generated in a single-worker stage and then used with these ops in later, parallelized stages.
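
For illustration, a minimal sketch of that pattern (the cropping transform and helper names are assumptions, not code from this repository): per-example seeds are drawn in a serialized map, then consumed by stateless ops in a parallelized map.

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # tf.data.experimental.AUTOTUNE before TF 2.4

def make_seed():
  # Serialized stage: a single NumPy RNG stream yields one shape-[2] seed
  # per example, in a deterministic order.
  return np.random.randint(0, 2**31 - 1, size=[2], dtype=np.int64)

def attach_seed(image):
  seed = tf.numpy_function(make_seed, inp=[], Tout=tf.int64)
  seed.set_shape([2])  # numpy_function drops static shape information
  return image, seed

def stateless_random_crop(image, seed):
  # Parallelized stage: stateless op, so the crop depends only on (image, seed).
  whole_image = tf.constant([[[0.0, 0.0, 1.0, 1.0]]], dtype=tf.float32)
  begin, size, _ = tf.image.stateless_sample_distorted_bounding_box(
      tf.shape(image), bounding_boxes=whole_image, seed=seed,
      area_range=(0.5, 1.0))
  return tf.slice(image, begin, size)

np.random.seed(123)
dataset = (
    tf.data.Dataset.from_tensor_slices(tf.zeros([8, 64, 64, 3]))
    .map(attach_seed, num_parallel_calls=1)                   # seed generation stays serial
    .map(stateless_random_crop, num_parallel_calls=AUTOTUNE)  # heavy work runs in parallel
)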