silvaurus opened 2 years ago
Just want to follow up that I implemented my own RandomFlip and RandomCrop, which fixed the problem in my implementation; a sketch of the approach is below.
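Roughly, my replacement layers own an explicitly seeded generator instead of relying on the global stateful random ops. A simplified sketch (the class name and per-image coin here are just illustrative; the RandomCrop analogue works the same way):

```python
import tensorflow as tf

class SeededRandomFlip(tf.keras.layers.Layer):
    """Horizontal flip driven by an explicitly seeded generator, so two
    runs with the same seed make identical augmentation choices."""

    def __init__(self, seed, **kwargs):
        super().__init__(**kwargs)
        # The layer owns its RNG state rather than using global stateful ops.
        self.rng = tf.random.Generator.from_seed(seed)

    def call(self, images, training=None):
        if not training:
            return images
        # One coin per image in the batch; deterministic given the seed.
        coins = self.rng.uniform([tf.shape(images)[0], 1, 1, 1])
        return tf.where(coins < 0.5,
                        tf.image.flip_left_right(images), images)
```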
But I'm not sure whether this is considered a bug. It would also be appreciated if you could help me figure out how to use the official data augmentation layers on TPU VMs and achieve reproducibility without having to define my own layers.
Hello!
I'm trying to run training workloads on TPU VMs and am trying to make the training reproducible, i.e., if I train with the same configuration twice, every value in every training iteration should be identical.
So far this is what I have done: (1) passed a seed to my dataset's shuffle; (2) used GlorotNormal with an explicit seed as the kernel initializer for each Conv/FC layer; (3) assigned seed values to the dropout layers.
This was sufficient for me to reproduce results with TensorFlow v2.6.
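Roughly, those three steps look like this (buffer sizes and layer arguments are just placeholders, not my actual configuration):

```python
import tensorflow as tf

SEED = 42

# (1) Seeded dataset shuffling.
dataset = tf.data.Dataset.range(1000).shuffle(1000, seed=SEED)

# (2) Seeded GlorotNormal kernel initializer on each Conv/FC layer.
conv = tf.keras.layers.Conv2D(
    64, 3, kernel_initializer=tf.keras.initializers.GlorotNormal(seed=SEED))

# (3) Seeded dropout.
drop = tf.keras.layers.Dropout(0.5, seed=SEED)
```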
However, this doesn't work when I add the following data augmentation layers:
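Something like this, seeded the same way as everything else (arguments simplified; under v2.6 these layers live at tf.keras.layers.experimental.preprocessing, under v2.8 directly at tf.keras.layers):

```python
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal", seed=SEED),
    tf.keras.layers.RandomCrop(224, 224, seed=SEED),
])
```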
I suspect that it is because of the following warnings:
I do not use truncated_normal or RandomUniform directly; I guess they are used inside these data augmentation layers. Since I initialize my Conv/FC weights with a seeded GlorotNormal (which itself draws from a truncated normal distribution), maybe that is why the truncated_normal warning does not affect my kernel initializers but only the data augmentation layers.
Later I moved my experiments to TensorFlow v2.8, still on TPU VMs, and it turned out that even without these data augmentation layers my results were no longer reproducible.
I suspect that the implementation of the GlorotNormal initializer changed between v2.6 and v2.8, since I received the same warnings:
And when I replaced GlorotNormal() with Ones() for all Conv/FC layers, my results became reproducible again (without data augmentation).
To conclude, it would be extremely helpful if you could help me make my training workload reproducible on TPU VMs (whether on v2.6 or v2.8). For example, would there be an easy way to make all tf.random.truncated_normal operations stateless?
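By "stateless" I mean ops whose output is a pure function of shape and seed, for example (just to illustrate what I have in mind):

```python
import tensorflow as tf

# Same shape + seed always produces the same values, independent of
# execution order or global RNG state.
noise = tf.random.stateless_truncated_normal(shape=[3, 3], seed=[1, 2])
```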
Thank you so much!