NVIDIA / framework-reproducibility

Providing reproducibility in deep learning frameworks
Apache License 2.0

tf.data.experimental.sample_from_datasets non-deterministic in multi-gpu. #39

Open yanniskar opened 2 years ago

yanniskar commented 2 years ago

Problem Overview

I train my model on the same dataset in two different setups: A) single-GPU and B) multi-GPU. Setup A produces deterministic results; setup B produces non-deterministic results. As soon as I replace the tf.data.experimental.sample_from_datasets call with a direct call to tf.data.Dataset, B also becomes deterministic.
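
Comparing setups A and B comes down to checking whether two runs yield the same elements in the same order. A minimal pure-Python sketch of one way to check that: reduce the whole ordered stream to a single digest and compare digests across runs. The helper name `stream_fingerprint` is hypothetical (not part of TensorFlow or this repo), and it assumes each batch has already been converted to bytes, for example with `tensor.numpy().tobytes()` on an eager tensor.

```python
import hashlib
from typing import Iterable


def stream_fingerprint(batches: Iterable[bytes]) -> str:
    """Return one SHA-256 digest summarizing an ordered stream of batches.

    Two runs give the same digest only if they yield byte-identical
    batches in the same order.
    """
    digest = hashlib.sha256()
    for batch in batches:
        # Hash each batch's length first so that [b"ab", b"c"] and
        # [b"a", b"bc"] do not produce the same digest.
        digest.update(len(batch).to_bytes(8, "little"))
        digest.update(batch)
    return digest.hexdigest()


if __name__ == "__main__":
    run_a = [b"batch-0", b"batch-1", b"batch-2"]
    run_b = [b"batch-0", b"batch-2", b"batch-1"]  # same data, different order
    print(stream_fingerprint(run_a) == stream_fingerprint(list(run_a)))  # True
    print(stream_fingerprint(run_a) == stream_fingerprint(run_b))  # False
```

Because the digest is order-sensitive, it flags a reordering introduced by the sampling step even when every individual element is identical.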


Python: 3.7, CUDA: 11.2, TensorFlow: 2.4.1


Relevant API: https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/data/experimental/sample_from_datasets
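
Conceptually, `sample_from_datasets` builds its output by drawing, for each element, which input to read next, according to the given weights and seed. The pure-Python model below is a simplification for illustration only, not TensorFlow's implementation. The function name `sample_from_iterators` is hypothetical, and its handling of exhausted inputs (stop as soon as the chosen input runs out) is an assumption for brevity; see the API page above for TensorFlow's actual behavior.

```python
import random
from typing import Iterator, List, Sequence


def sample_from_iterators(sources: List[Iterator], weights: Sequence[float],
                          seed: int) -> Iterator:
    """Yield elements by repeatedly picking a source according to `weights`.

    Simplification: stops as soon as the chosen source is exhausted.
    """
    rng = random.Random(seed)  # private seeded RNG: same seed -> same picks
    indices = list(range(len(sources)))
    while True:
        # Weighted draw of which source supplies the next element.
        i = rng.choices(indices, weights=weights, k=1)[0]
        try:
            element = next(sources[i])
        except StopIteration:
            return
        yield element


if __name__ == "__main__":
    def make_sources():
        return [iter(range(0, 10)), iter(range(100, 110))]

    first = list(sample_from_iterators(make_sources(), [0.7, 0.3], seed=7))
    second = list(sample_from_iterators(make_sources(), [0.7, 0.3], seed=7))
    print(first == second)  # True: same seed and weights, same interleaving
```

In this model the interleaving depends only on the weights and the seed, so repeated runs agree. The thread itself does not establish why the real op diverges in the multi-GPU setup.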

import json
import os
from typing import Callable, Tuple

import tensorflow as tf

# Not shown in this post (proprietary): constants, Distribution, file_py,
# MAX_SHUFFLE_BUFFER_SIZE, REFERENCE_CLASS and _get_class_distribution.


def load_resampled_data(dataset_dir: str, split: str, batch_size: int,
                        prepare_example: Callable,
                        distribution=Distribution.DEFAULT
                        ) -> Tuple[tf.data.Dataset, int]:
    """Load the samples in dataset_dir in a shuffled order
    at per-class sampling rates determined by distribution.

    :param dataset_dir: Path to the dataset directory.
    :param split: One of the values in constants.VALID_SPLITS.
    :param batch_size: Number of samples per batch.
    :param prepare_example: Function to apply to every sample.
    :param distribution: Distribution enum indicating the
        distribution over label classes for each epoch.
    :return: The batched dataset and its cardinality.
    """
    assert split in constants.VALID_SPLITS
    class_datasets = []
    tf_records_dir = os.path.join(dataset_dir, split)

    # Load class cardinality information.
    # (The file-name argument was cut off in the original post.)
    class_cardinality_json = os.path.join(tf_records_dir, ...)
    with file_py.open(class_cardinality_json, 'r') as f:
        class_cardinality = json.load(f)

    # Determine train-time distribution.
    # (The remaining arguments were cut off in the original post.)
    class_distribution = _get_class_distribution(class_cardinality, ...)
    assert round(sum(class_distribution.values()), 2) == 1.0
    print("Train-time class distribution:", class_distribution)

    # Load class-based tf records with re-sampling.
    resampled_distribution = []
    for class_name, class_weight in class_distribution.items():
        tf_record = os.path.join(tf_records_dir, f"{class_name}.tf_record")
        class_dataset = tf.data.TFRecordDataset(tf_record)
        assert class_cardinality[class_name] > 0, class_cardinality
        # Any further shuffle() arguments were cut off in the original post.
        class_dataset = class_dataset.shuffle(
            min(class_cardinality[class_name], MAX_SHUFFLE_BUFFER_SIZE))
        # Reconstructed (these lines are missing from the post): collect each
        # per-class dataset and its weight for sample_from_datasets below.
        class_datasets.append(class_dataset)
        resampled_distribution.append(class_weight)

    # (The divisor was cut off in the original post.)
    dataset_cardinality = int(class_cardinality[REFERENCE_CLASS] / ...)
    dataset = tf.data.experimental.sample_from_datasets(
        class_datasets, resampled_distribution, seed=constants.SEED)

    # Elements cannot be processed in parallel because
    # of the stateful non-determinism in the data augmentations.
    dataset = dataset.map(
        prepare_example, num_parallel_calls=1, deterministic=True)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset.prefetch(1), dataset_cardinality

I can't provide the full code because it is proprietary, but above is the data-loading portion. If you need more information to root-cause this, let me know and I will see what I can provide. FYI, the main code sets all the seeds correctly and disables Horovod tensor fusion, as suggested in the repo README.
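
For context, a seed-and-environment setup of the kind described above might look like the sketch below, since the author's main code is not shown. It assumes the README's Horovod advice is to disable Tensor Fusion through the `HOROVOD_FUSION_THRESHOLD=0` environment variable and that `TF_DETERMINISTIC_OPS=1` is used to request deterministic GPU ops on TF 2.4. The function name `configure_determinism` and the seed value are hypothetical.

```python
import os
import random

SEED = 42  # hypothetical value; the author's constants.SEED is not shown


def configure_determinism(seed: int = SEED) -> None:
    """Best-effort determinism setup for a TF 2.4 + Horovod job (sketch)."""
    # Set before TensorFlow/Horovod initialize so the settings take effect.
    os.environ["TF_DETERMINISTIC_OPS"] = "1"
    # Assumption: disables Horovod Tensor Fusion, per the README's advice.
    os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
    # Seed every random number generator the training code might use.
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)  # global seed for TF ops and tf.data
    except ImportError:
        pass  # TensorFlow not installed; nothing to seed
```

Call it once at the very top of the entry point, before building any model or tf.data pipeline.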

Thanks a lot for the great work on making TensorFlow deterministic. It, along with the documentation provided, has been incredibly useful in my day-to-day work.

duncanriach commented 2 years ago

Hi @yanniskar, thank you for isolating this issue, and for providing a reproducer. Please will you open an issue against the stock TensorFlow repo for this. Please reference this current issue. Once you have opened that, I will close this issue.

duncanriach commented 2 years ago

Also @yanniskar, for my records of customers who value this work, please can I know a bit more about you? Please could you give me your name and/or your affiliation?

yanniskar commented 2 years ago


Done: https://github.com/tensorflow/tensorflow/issues/53846. To answer your other question, here is my LinkedIn: https://www.linkedin.com/in/yannis-karakozis-746488116/

Thanks for the help on this one :)

duncanriach commented 2 years ago

Thank you, and my pleasure.

BTW, from TensorFlow version 2.7 onwards, you no longer need to manually serialize your dataset.map preprocessing (num_parallel_calls=1). @reedwm added a step that makes this functionality deterministic and that will, in the future, parallelize it as much as possible (by splitting the stateless, computationally intensive work away from the stateful, computationally light work and parallelizing the former).
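
The split described above can be modeled in plain Python. The stateless, expensive step runs in parallel through an order-preserving map, and the stateful random step then runs serially from a single seeded generator. This is only a sketch of the idea, not the tf.data implementation; `heavy_stateless` and `preprocess` are hypothetical names.

```python
import random
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


def heavy_stateless(x: int) -> int:
    """Stand-in for expensive, deterministic preprocessing (e.g. decoding)."""
    return sum(i * i for i in range(x % 1000))


def preprocess(elements: List[int], seed: int,
               workers: int) -> List[Tuple[int, bool]]:
    # Stage 1: stateless work in parallel. Executor.map returns results in
    # input order, so the worker count cannot change the output order.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        decoded = list(pool.map(heavy_stateless, elements))
    # Stage 2: stateful random augmentation, run serially with one seeded
    # RNG so each element draws its random numbers in a fixed order.
    rng = random.Random(seed)
    return [(value, rng.random() < 0.5) for value in decoded]  # random "flip" flag


if __name__ == "__main__":
    data = list(range(50))
    print(preprocess(data, seed=3, workers=1) == preprocess(data, seed=3, workers=8))  # True
```

Because the parallel stage consumes no random state and preserves input order, the output is identical for any number of workers; only the serial stage touches the RNG.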