grofte opened this issue 4 years ago (status: Open)
/cc @shreyashpatodia
@shreyashpatodia This needs a bit of attention though. I think it's mostly just copying over what I wrote in the modified notebook and checking that it's okay. And maybe changing `model.variables` to `model.trainable_variables`. I don't know if there's ever an application where you wouldn't want to use `model.trainable_variables`?
EDIT: Two ways of doing the forward pass to adjust batch normalisation.
System information
Describe the bug
The guide https://www.tensorflow.org/addons/tutorials/average_optimizers_callback#stocastic_averaging does not explain how to do Stochastic Weight Averaging (or the model averaging thing either, but I'm not sure what that is supposed to do). It trains three different models, but before evaluation the weights from the first model are always loaded, due to differences in how the two Checkpoint callbacks save their weights. And SWA does not require checkpoints; in fact, they seemed to break the training somehow. You can do SWA manually if you have a bunch of checkpoints, but the TensorFlow Addons way is much easier.
Code to reproduce the issue
Here is the modified notebook: https://drive.google.com/file/d/1XxXq6VwoRmvrOQbkqLii9CLtyy8jFTba/view?usp=sharing. I've changed SGD to NAdam, but I'm not sure that was necessary. I also changed the filepath that the moving average code loads so it gets the right weights, but again, I don't know if it is working since I'm not sure what it is supposed to do. So I commented out the code.
I also added a validation split so the user can see that the final validation score after SWA is applied is greater than any validation score after an individual epoch. The reader also gets a sense that the test set appears to differ from the training set in some systematic way.
Addendum
I would consider adding a BatchNormalization layer to the model, or adding an additional example with BatchNormalization. You have to complete one forward pass with `training=True` after running `stocastic_avg_nadam.assign_average_vars(model.variables)` to complete SWA. It's very important and easy to miss (but no, I don't know how you are supposed to run a forward pass on the model when you use a tf.data.Dataset).
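To illustrate the two ways of doing that forward pass mentioned in the edit above, here is a sketch (the model and data are hypothetical): calling the model directly with `training=True` recomputes the BatchNormalization moving statistics under the averaged weights, either on one array or batch-by-batch over a tf.data.Dataset.

```python
import numpy as np
import tensorflow as tf

# Hypothetical small model containing a BatchNormalization layer
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])

x = np.random.rand(32, 4).astype("float32")

# Way 1: one forward pass on an in-memory array with training=True,
# so the BN layer updates its moving mean/variance (the prediction
# itself is discarded)
_ = model(x, training=True)

# Way 2: with a tf.data.Dataset, iterate over the batches and do the
# same forward pass on each one
ds = tf.data.Dataset.from_tensor_slices(x).batch(16)
for batch in ds:
    _ = model(batch, training=True)
```

Note this only runs inference-style forward passes; no gradients are computed, so the averaged weights themselves are untouched and only the BN statistics move.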