tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Guide for Model Averaging is incorrect for Stochastic Weight Averaging #2037

Open: grofte opened this issue 4 years ago

grofte commented 4 years ago


Describe the bug

The guide https://www.tensorflow.org/addons/tutorials/average_optimizers_callback#stocastic_averaging does not explain how to do Stochastic Weight Averaging (SWA), or the model averaging part either, though I'm not sure what that is supposed to do. It trains three different models, but before evaluation the weights from the first model are always loaded, because the two different checkpoint callbacks save their weights in different ways. Also, SWA does not require checkpoints; in fact, using them seemed to break training somehow. You can do SWA manually if you have a bunch of checkpoints, but the TensorFlow Addons way is much easier.
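Doing SWA by hand from a set of checkpoints comes down to taking the element-wise mean of the saved weights. The sketch below assumes each snapshot is a list of NumPy arrays in the same layer order, for example what model.get_weights() returns after model.load_weights(...) on one checkpoint. average_weight_snapshots is a hypothetical helper for illustration, not a TensorFlow Addons API.

```python
import numpy as np


def average_weight_snapshots(snapshots):
    """Element-wise mean of several weight snapshots (hypothetical helper).

    Each snapshot is a list of NumPy arrays in the same order, e.g. the
    output of model.get_weights() after loading one checkpoint.
    """
    if not snapshots:
        raise ValueError("need at least one snapshot")
    # zip(*snapshots) groups the same weight tensor across all snapshots.
    return [np.mean(np.stack(arrays), axis=0) for arrays in zip(*snapshots)]


# Two toy snapshots of a model with a 1x2 kernel and a bias of length 1.
snap_a = [np.array([[1.0, 2.0]]), np.array([0.0])]
snap_b = [np.array([[3.0, 4.0]]), np.array([2.0])]
averaged = average_weight_snapshots([snap_a, snap_b])
print(averaged[0], averaged[1])  # [[2. 3.]] [1.]
```

The result could then be loaded with model.set_weights(averaged). If the model has BatchNormalization layers, get_weights() also returns their moving statistics, which are recomputed with a forward pass rather than averaged, as described in the addendum below.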

Code to reproduce the issue

Here is the modified notebook: https://drive.google.com/file/d/1XxXq6VwoRmvrOQbkqLii9CLtyy8jFTba/view?usp=sharing. I changed SGD to NAdam, though I'm not sure that was necessary. I also changed the file path that the moving average code loads from so it gets the right weights, but again, I don't know whether it works, since I'm not sure what it is supposed to do, so I commented that code out.
I also added a validation split so the user can see that the final validation score after SWA is applied is higher than the validation score after any individual epoch. It also gives the reader an idea that the test set appears to differ slightly from the training set in some systematic way.

Addendum

I would consider adding a BatchNormalization layer to the model, or adding a separate example with BatchNormalization. To complete SWA, you have to run one forward pass with training=True after running stocastic_avg_nadam.assign_average_vars(model.variables). It's very important and easy to miss (but no, I don't know how you are supposed to run a forward pass on the model when you use a tf.data.Dataset).
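One possible way to run that forward pass with a tf.data.Dataset of (features, labels) pairs is to map the dataset down to the features and call the model with training=True on each batch. This is a minimal sketch with hypothetical toy data and a hypothetical small model; in the notebook, the loop would come right after stocastic_avg_nadam.assign_average_vars(model.variables).

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for the tutorial's (features, labels) dataset.
rng = np.random.RandomState(0)
features = rng.normal(loc=3.0, scale=2.0, size=(256, 4)).astype("float32")
labels = np.zeros((256, 1), dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

# Hypothetical small model with one BatchNormalization layer.
bn = tf.keras.layers.BatchNormalization()
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),
    bn,
    tf.keras.layers.Dense(1),
])

# Assumption: the averaged SWA weights have already been assigned at this point,
# e.g. with stocastic_avg_nadam.assign_average_vars(model.variables).
mean_before = bn.moving_mean.numpy().copy()

# The model expects inputs only, so drop the labels from each dataset element.
features_only = dataset.map(lambda x, y: x)

# training=True puts BatchNormalization in training mode, which updates its
# moving mean and variance; the model outputs themselves are discarded.
for batch in features_only:
    model(batch, training=True)

mean_after = bn.moving_mean.numpy()
print("moving mean changed:", not np.allclose(mean_before, mean_after))
```

Keras BatchNormalization keeps exponential moving averages (default momentum=0.99), so the number of batches you pass through affects how far the statistics move away from their pre-averaging values.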

bhack commented 4 years ago

/cc @shreyashpatodia

grofte commented 4 years ago

@shreyashpatodia This needs a bit of attention, though. I think it's mostly a matter of copying over what I wrote in the modified notebook and checking that it's okay.

And maybe change model.variables to model.trainable_variables. I don't know whether there's ever an application where you wouldn't want to use model.trainable_variables.
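The difference between the two lists shows up once the model has a BatchNormalization layer: model.variables also contains the layer's moving mean and moving variance, which are non-trainable and are not updated by the optimizer, while model.trainable_variables leaves them out. A small hypothetical model makes this concrete:

```python
import tensorflow as tf

# Hypothetical small model with one BatchNormalization layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])

# model.variables: Dense kernels/biases, BN gamma/beta, and BN moving statistics.
print("all:", len(model.variables))
# trainable_variables: only what the optimizer updates (kernels, biases, gamma, beta).
print("trainable:", len(model.trainable_variables))
# The difference is the BatchNormalization moving mean and moving variance.
for v in model.non_trainable_variables:
    print("non-trainable:", v.name)
```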

EDIT: Two ways of doing the forward pass to update the batch normalisation statistics:

  1. Recompile the model with an SGD optimizer with a learning rate of 1e-12 and run 1 epoch. The disadvantage is that the model is no longer saved with the proper optimizer, and it still calculates and applies gradients, which costs time.
  2. Run the model as a forward pass only. The disadvantage is that if you use a TF Dataset, you have to remap it so it contains only the features and not the labels (you don't need to do that if you use a different ingestion pipeline). I'm also not sure whether this is future-proof, or whether Keras will change so that a forward pass no longer updates normalization layers.
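Option 1 can be sketched as below, again with a hypothetical small model and toy data. The near-zero learning rate means fit() still runs every batch in training mode, so BatchNormalization updates its moving statistics, while the resulting weight updates are negligible. As noted in the list, this replaces the SWA optimizer on the model and still pays for the gradient computation.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy regression data.
rng = np.random.RandomState(0)
x_train = rng.normal(loc=3.0, scale=2.0, size=(256, 4)).astype("float32")
y_train = rng.normal(size=(256, 1)).astype("float32")

# Hypothetical small model with one BatchNormalization layer.
dense_in = tf.keras.layers.Dense(8)
bn = tf.keras.layers.BatchNormalization()
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), dense_in, bn, tf.keras.layers.Dense(1)])

# Assumption: the averaged SWA weights were assigned just before this point.
kernel_before = dense_in.kernel.numpy().copy()
mean_before = bn.moving_mean.numpy().copy()

# Option 1: recompile with a near-zero learning rate and run one epoch.
# fit() runs every batch in training mode, so BatchNormalization updates its
# moving statistics, while the gradient steps barely move the weights.
# Note that this replaces the SWA optimizer that was compiled into the model.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-12), loss="mse")
model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)

print("kernel unchanged:", np.allclose(kernel_before, dense_in.kernel.numpy(), atol=1e-6))
print("moving mean changed:", not np.allclose(mean_before, bn.moving_mean.numpy()))
```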