tensorflow / tpu

Reference models and tools for Cloud TPUs.
https://cloud.google.com/tpu/
Apache License 2.0

Reproducibility when training models on TPU VMs #971

Open silvaurus opened 2 years ago

silvaurus commented 2 years ago

Hello!

I'm trying to run training workloads on TPU VMs and am trying to make the training reproducible, i.e., if I train with the same configuration twice, all values in every single training iteration should be identical.

So far, this is what I have done: (1) I passed a seed to the shuffling of my dataset. (2) For each Conv/FC layer, I used GlorotNormal and gave the kernel initializer a seed. (3) I assigned seed values to the dropout layers.
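For reference, a minimal sketch of these three steps (the dataset, layer sizes, and seed value below are just placeholders):

```python
import tensorflow as tf

SEED = 42  # placeholder seed value

# (1) Seed the dataset shuffle.
dataset = tf.data.Dataset.range(1000).shuffle(buffer_size=1000, seed=SEED)

# (2) Seed the kernel initializer of each Conv/FC layer.
conv = tf.keras.layers.Conv2D(
    filters=64, kernel_size=3,
    kernel_initializer=tf.keras.initializers.GlorotNormal(seed=SEED))
fc = tf.keras.layers.Dense(
    units=10,
    kernel_initializer=tf.keras.initializers.GlorotNormal(seed=SEED))

# (3) Seed the dropout layers.
dropout = tf.keras.layers.Dropout(rate=0.5, seed=SEED)
```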

This was sufficient for me to reproduce results with TensorFlow v2.6.

However, this doesn't work when I add the following data augmentation layers:

tf.keras.layers.experimental.preprocessing.RandomFlip
tf.keras.layers.experimental.preprocessing.RandomCrop
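
Roughly, they are attached like this (a minimal sketch; the crop size and seed are illustrative):

```python
import tensorflow as tf

augment = tf.keras.Sequential([
    # Both layers accept a seed, which I set just like for the other layers.
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal", seed=42),
    tf.keras.layers.experimental.preprocessing.RandomCrop(height=224, width=224, seed=42),
])
```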

I suspect that it is because of the following warnings:

2022-03-07 21:59:58.943609: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:157] Warning: Using tf.random.truncated_normal with XLA compilation will ignore seeds; consider using tf.random.stateless_truncated_normal instead if reproducible behavior is desired. TruncatedNormal
2022-03-07 21:59:59.000776: I tensorflow/compiler/jit/xla_compilation_cache.cc:334] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2022-03-07 22:00:02.285634: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:54] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. RandomUniform

I do not use truncated_normal or RandomUniform directly; I suspect they are used inside these data augmentation layers. Since I use GlorotNormal to initialize the weights of my Conv/FC layers, perhaps that is why the truncated_normal warning does not affect my kernel initializers but only the data augmentation layers.

Later I moved my experiments to TensorFlow v2.8, still on TPU VMs, and it turns out that even without these data augmentation layers my results are no longer reproducible.

I suspect that the implementation of the GlorotNormal initializer changed between v2.6 and v2.8, as I received the same warning:

2022-03-13 18:17:17.010734: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:167] Warning: Using tf.random.truncated_normal with XLA compilation will ignore seeds; consider using tf.random.stateless_truncated_normal instead if reproducible behavior is desired. TruncatedNormal

And when I replaced GlorotNormal() with Ones() for all Conv/FC layers, my results became reproducible again (without data augmentation).

To conclude, it would be extremely helpful if you could help me make my training workload reproducible on TPU VMs (whether on v2.6 or v2.8). For example, is there an easy way to make all tf.random.truncated_normal operations stateless?
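For instance, something along these lines is what I have in mind (a minimal sketch using the stateless ops that the XLA warning suggests; I have not verified that this alone restores reproducibility on TPU VMs):

```python
import tensorflow as tf

# Stateless ops take an explicit 2-element seed and return the same
# values for the same seed, instead of relying on global RNG state.
seed = tf.constant([1, 2], dtype=tf.int32)
w = tf.random.stateless_truncated_normal(shape=(3, 3, 64, 64), seed=seed)
u = tf.random.stateless_uniform(shape=(4,), seed=seed)

# TF >= 2.7 also has a single helper that seeds Python, NumPy and TF:
tf.keras.utils.set_random_seed(42)
```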

Thank you so much!

silvaurus commented 2 years ago

Just want to follow up: I implemented my own RandomFlip and RandomCrop, which fixed the problem in my case.
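The idea is roughly the following, a minimal sketch built on the stateless image ops (my actual layers differ in the details; the crop size and seeds are illustrative):

```python
import tensorflow as tf

def stateless_flip_and_crop(image, seed, crop_size=(224, 224, 3)):
    # The same (image, seed) pair always yields the same output, so the
    # augmentation is reproducible across runs.
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    image = tf.image.stateless_random_crop(image, size=crop_size, seed=seed)
    return image

# Usage on a single example; in practice the seed is derived per example,
# e.g. from the example index.
img = tf.zeros([256, 256, 3])
out = stateless_flip_and_crop(img, seed=tf.constant([3, 7], dtype=tf.int32))
```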

But I'm not sure whether this is considered a bug. It would also be appreciated if you could help me figure out how to use the official data augmentation layers on TPU VMs and achieve reproducibility without having to define my own layers.