NVIDIA / framework-reproducibility

Providing reproducibility in deep learning frameworks

garder14/byol-tensorflow2 (batch-norm & softmax/cross-entropy) #35

Open evolu8 opened 3 years ago

evolu8 commented 3 years ago

Running TF 2.4.1 with seeds and env vars set, I'm getting different results on each run for this repo:

https://github.com/garder14/byol-tensorflow2

I currently suspect it's the gradient tape. Not sure how to handle that. Would downgrading the TF version help?

Thoughts welcome.
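
For context, here's roughly the kind of seed/env setup I mean (a minimal sketch with illustrative values and calls, not the literal training script):

```python
# Minimal sketch of a seeds-and-env-vars setup for TF 2.4.1 (illustrative
# values; not the literal training script). TF_DETERMINISTIC_OPS requests
# deterministic GPU kernels where they are available.
import os
import random

os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['PYTHONHASHSEED'] = '0'

import numpy as np
import tensorflow as tf  # imported after the env vars are set

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
```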

duncanriach commented 3 years ago

Sorry for the delay in responding, Phil; I was on vacation.

I have not run this code or dug into debugging it. Just from looking at it, I can see a couple of likely sources of nondeterminism:

  1. tf.keras.layers.BatchNormalization is instantiated in five places in models.py. This layer uses fused batch-norm functionality, which is nondeterministic when used for fine-tuning. I don't know under exactly what circumstances the Keras layer exposes that and, since I wasn't aware of this exposure until now, I have yet to document it. (See the first sketch after this list.)
  2. tf.nn.sparse_softmax_cross_entropy_with_logits is used in linearevaluation.py on the output of the ClassificationHead. This op will introduce nondeterminism, and there is a work-around for it. (See the second sketch after this list.)
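
For item 1, one way to test whether the fused batch-norm kernels are the source is to force the non-fused implementation. This is a diagnostic sketch and an assumption on my part, not a confirmed fix for the byol-tensorflow2 code, and it will be slower:

```python
import tensorflow as tf

# Diagnostic sketch (an assumption, not a confirmed fix for byol-tensorflow2):
# force the non-fused batch-norm implementation to test whether the fused
# kernels are the source of the run-to-run variation.
# In models.py, each of the five instantiations would become, e.g.:
bn = tf.keras.layers.BatchNormalization(fused=False)
```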
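For item 2, the commonly cited work-around is to compute the loss from tf.nn.log_softmax instead of the fused op, which avoids the nondeterministic backprop kernel. A minimal sketch (the helper name and exact formulation are my own, illustrative choices):

```python
import tensorflow as tf

def sparse_softmax_cross_entropy_deterministic(labels, logits):
    """Sketch of a work-around for the nondeterministic backprop of
    tf.nn.sparse_softmax_cross_entropy_with_logits: compute the loss from
    log-softmax and a one-hot mask, whose gradients use deterministic ops.
    (Function name and formulation are illustrative assumptions.)"""
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    num_classes = tf.shape(logits)[-1]
    one_hot = tf.one_hot(labels, depth=num_classes, dtype=log_probs.dtype)
    return -tf.reduce_sum(one_hot * log_probs, axis=-1)

# In linearevaluation.py, the call on the ClassificationHead output would
# become something like:
#   loss = tf.reduce_mean(
#       sparse_softmax_cross_entropy_deterministic(labels=y, logits=logits))
```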

Answering your specific questions/comments:

  1. "Would downgrading TF version help?": No, and downgrading is very unlikely to ever help. We're trying hard to avoid regressions regarding determinism.
  2. "I currently suspect it's the gradient tape": Gradient tape just means something in the backprop. Both of the above-mentioned sources would lead to an introduction of noise in the backprop.