NVIDIA / framework-reproducibility

Providing reproducibility in deep learning frameworks
Apache License 2.0

Non-reproducible model training results with TensorFlow 2.5 #38

Closed: brandhsu closed this issue 2 years ago

brandhsu commented 2 years ago

I am having trouble getting reproducible results when I run model training multiple times with the same dataset, model, training loop, and seed.

Environment

Code

Function to set random seed.

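import os
import random

import numpy as np
import tensorflow as tf

def set_seeds(seed=0):
    # note: only affects child processes; this interpreter's hash seed was fixed at startup
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # request deterministic GPU op implementations (TF 2.5-era environment variables)
    os.environ["TF_DETERMINISTIC_OPS"] = "1"
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

    # called once at beginning of program
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

def sum_of_weights(model):
    # sum of every element of every weight array in the model
    return sum(map(lambda x: x.sum(), model.get_weights()))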
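Summing all weights into a single number is a coarse check: two different sets of weights can have the same sum, and the sum is itself rounded. A stricter comparison is sketched below. `weights_identical` and `max_abs_difference` are hypothetical helpers (not part of Keras) that take the lists of NumPy arrays returned by `model.get_weights()`.

```python
import numpy as np


def weights_identical(weights_a, weights_b):
    """Return True if two lists of weight arrays (e.g. from model.get_weights())
    have the same shapes and exactly equal values, element by element."""
    if len(weights_a) != len(weights_b):
        return False
    # np.array_equal returns False on a shape mismatch as well as a value mismatch
    return all(np.array_equal(a, b) for a, b in zip(weights_a, weights_b))


def max_abs_difference(weights_a, weights_b):
    """Largest element-wise absolute difference across all weight arrays.
    Assumes both lists have matching shapes and contain no empty arrays."""
    return max(float(np.max(np.abs(a - b))) for a, b in zip(weights_a, weights_b))


# Toy weights: run_c differs from run_a in one element, while run_d has the
# same total sum as run_a but with two values swapped.
run_a = [np.array([[1.0, 2.0]]), np.array([0.5])]
run_b = [np.array([[1.0, 2.0]]), np.array([0.5])]
run_c = [np.array([[1.0, 2.0]]), np.array([0.25])]
run_d = [np.array([[2.0, 1.0]]), np.array([0.5])]

print(weights_identical(run_a, run_b))   # True
print(weights_identical(run_a, run_c))   # False
print(max_abs_difference(run_a, run_c))  # 0.25
print(weights_identical(run_a, run_d))   # False, even though the sums match
```

Assuming each `models[i]` holds its trained model, as in the loop below, `weights_identical(models[0].get_weights(), models[1].get_weights())` would compare two runs directly, and `max_abs_difference` shows how far apart they drifted.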

Model training loop: here the exact same model is trained repeatedly on the same dataset with the same seed.

data = {}
models = {}
pre_weights = {}
post_weights = {}

# Loop over model hyperparameters
for i, hyperparams in enumerate(all_hyperparams):

    # Load dataset generators
    client = Dataset(hyperparams).get_client(0)
    gen_train, gen_valid = client.create_generators()

    # Create model
    model = Model(client).create()
    pre_weights[i] = reproducibility.sum_of_weights(model) # store sum of initial weights

    # Create a trainer
    trainer = Trainer(hyperparams)

    # Train model (under the hood, same as model.fit)
    history = trainer.fit(
        model,
        gen_train,
        gen_valid,
        iters=500,
        steps_per_epoch=100,
        validation_freq=5,
        callbacks=None,
    )
    post_weights[i] = reproducibility.sum_of_weights(history.model) # store sum of trained weights

    models[i] = model # store model

    x, y = next(gen_train)
    data[i] = np.sum(x['dat']) # store sum of a batch of data

The set_seeds function is called in the __init__ of each class (Dataset, Model, Trainer), i.e. every time one of them is instantiated. The loss function is tf.keras.losses.SparseCategoricalCrossentropy.


Output

# sum of initial model weights
{0: 1377.537437170744, 1: 1377.537437170744, 2: 1377.537437170744, 3: 1377.537437170744, 4: 1377.537437170744}

# sum of trained model weights
{0: 1167.8577085053264, 1: 1167.3095767943782, 2: 1165.1663435424903, 3: 1165.0273841904536, 4: 1167.2645659474476}

# sum of a batch of data
{0: -3.637978807091713e-12, 1: -3.637978807091713e-12, 2: -3.637978807091713e-12, 3: -3.637978807091713e-12, 4: -3.637978807091713e-12}

The initial weights appear to be seeded correctly, but the weights after training differ from run to run. I wonder what the issue could be...

Am I missing something? All of this was trained on the same GPU and in the same kernel session.

I recall that the reduce_ operations were nondeterministic at one point, but I assume they have been fixed. Are there any recommended resources that shed light on potentially nondeterministic operations, or suggestions for ensuring reproducibility?

Thanks in advance 😊. Being able to reproduce results is tremendously valuable, whether for general model comparison or for life-impacting applications.
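On the reduction question: one common reason a GPU reduction can be nondeterministic is that floating-point addition is not associative. If the order in which partial sums are combined varies between runs (for example, when they are accumulated with atomic operations whose order depends on thread scheduling), the rounded result can vary too. A minimal NumPy sketch of the effect, using a hypothetical helper `sequential_sum` and float32 values chosen to make the rounding obvious:

```python
import numpy as np


def sequential_sum(values):
    """Add values one at a time in float32, in exactly the order given."""
    total = np.float32(0.0)
    for v in values:
        total = np.float32(total + np.float32(v))
    return total


# Same three numbers, two accumulation orders.
forward = sequential_sum([1e8, 1.0, -1e8])    # 1e8 + 1 rounds back to 1e8 in float32, so the 1 is lost
reordered = sequential_sum([1e8, -1e8, 1.0])  # the large terms cancel first, so the 1 survives
print(forward, reordered)
```

In real training the per-step differences are usually tiny, but they can compound over many iterations, which would be consistent with runs that start from identical weights and end up slightly apart, as in the output above.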

yanniskar commented 2 years ago

According to the presentation here, there is a probe tool in the tfdeterminism package. I have not used it myself, but it could help you pinpoint where in your architecture the nondeterminism is introduced.

duncanriach commented 2 years ago

Hi @Brandhsu, could you please try running with TensorFlow version 2.8.0-rc0? Enable op determinism using tf.config.experimental.enable_op_determinism instead of TF_DETERMINISTIC_OPS. This latest version of TensorFlow will throw an exception if a nondeterministic op is used. Note that, if your model uses convolutions, there is a determinism bug that was introduced in version 2.5 and still exists in 2.8.0-rc0 (see stock TF issue 53771).

Regarding @yanniskar's comment above, the nondeterminism debug tool has not (yet) been released publicly.

brandhsu commented 2 years ago

Hi @duncanriach, it seems you and your team were spot on: the issue was due to the convolution nondeterminism bug (my model was indeed using convolutions). I was able to achieve reproducible results with version 2.8.0-rc0.

The following changes were made.

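pip install tensorflow==2.8.0rc0

import tensorflow as tf

def set_seeds(seed=0):
    # compare (major, minor) as integers; tf.__version__ is e.g. "2.8.0-rc0"
    major, minor = (int(v) for v in tf.__version__.split(".")[:2])
    if (major, minor) >= (2, 8):
        tf.keras.utils.set_random_seed(seed)  # seeds Python, NumPy, and TensorFlow RNGs
        tf.config.experimental.enable_op_determinism()  # raises if a nondeterministic op is used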

The output after training is shown below.

# sum of initial weights
{0: 1377.537437170744, 1: 1377.537437170744, 2: 1377.537437170744, 3: 1377.537437170744, 4: 1377.537437170744}

# sum of final weights
{0: 1266.9917370815285, 1: 1266.9917370815285, 2: 1266.9917370815285, 3: 1266.9917370815285, 4: 1266.9917370815285}

# sum of a batch of data
{0: -9.094947017729282e-12, 1: -9.094947017729282e-12, 2: -9.094947017729282e-12, 3: -9.094947017729282e-12, 4: -9.094947017729282e-12}

I will run the same models on different GPUs and report back on whether the behavior is also reproducible there.

Thanks to you and your team for your work on this issue; looking forward to the stable patch!

brandhsu commented 2 years ago

Awesome, the models were reproducible across different GPUs!

duncanriach commented 2 years ago

Wonderful news, @Brandhsu. Thanks for reporting back and closing.

I would like to clarify that many possible sources of nondeterminism were addressed between TensorFlow version 2.5 and 2.8.0-rc0, and any one of these (or possibly more than one) could have been thwarting your reproducibility. The 2.8.0-rc0 release actually still contains the bug related to deterministic selection of convolution algorithms (#53771), and, since (I presume) you did not apply the work-around (TF_CUDNN_USE_FRONTEND=1), that bug was not the source of the nondeterminism that you were seeing.

If you, or anyone else reading this, is interested, the release notes for 2.6, 2.7, and 2.8 contain info about all the determinism-related changes in those versions.
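For readers who do want to apply the TF_CUDNN_USE_FRONTEND=1 work-around for #53771 on an affected version such as 2.8.0-rc0, one option is to set the variable from Python at the very top of the training script. This is a minimal sketch; setting it before TensorFlow is imported is a precaution assumed here rather than a requirement stated in this thread, and the variable could equally be exported in the shell that launches the script.

```python
import os

# Work-around for stock TF issue #53771 (nondeterministic selection of
# convolution algorithms), mentioned above for TensorFlow 2.8.0-rc0.
# Set here, before TensorFlow is imported, as a precaution so the value is
# already in place when the first cuDNN convolution runs.
os.environ["TF_CUDNN_USE_FRONTEND"] = "1"

# import tensorflow as tf  # import TensorFlow only after this point
```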

cbhushan commented 1 year ago

Quoting @Brandhsu: "Awesome, the models were reproducible across different GPUs!"

@Brandhsu, I know it has been a while, but do you recall whether these different GPUs had the same or different architectures? Thanks.

brandhsu commented 1 year ago

Hi @cbhushan, they were different GPUs, but with the same architecture.