pbaylies / stylegan-encoder

StyleGAN Encoder - converts real images to latent space

Stochastic Weight Averaging bug #13

Closed: aydao closed this issue 5 years ago

aydao commented 5 years ago

I believe there is a bug in the implementation of stochastic weight averaging. Specifically, inside the apply_swa function in the network code, the scaling appears incorrect: the weights of the newly read model are scaled up by the epoch count:

        tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * epoch + self.vars[name])/(epoch + 1) for name in names}))

The result is that, regardless of which models swa.py reads in, the last pkl it reads is scaled so heavily that it overwrites pretty much all of the weights in the current running average. For example, the tenth model is weighted massively (i.e., by epoch=10) relative to the ones that came before it.
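To make the effect concrete, here is a minimal standalone sketch (plain Python floats standing in for the weight tensors, no tfutil or TensorFlow involved, with ten toy "models" whose true mean is 5.5) comparing the buggy update with a corrected running average:

    # Toy weights for ten models: 1.0, 2.0, ..., 10.0 (true mean is 5.5).
    models = [float(i) for i in range(1, 11)]

    buggy = fixed = models[0]
    for epoch, new in enumerate(models[1:], start=1):
        # Buggy rule: the *new* model is scaled by epoch, so the last one dominates.
        buggy = (new * epoch + buggy) / (epoch + 1)
        # Corrected rule: the new model gets 1/(epoch+1); the average keeps the rest.
        scale_new = 1.0 / (epoch + 1)
        fixed = new * scale_new + fixed * (1.0 - scale_new)

    print(buggy)  # ~9.89, dragged almost entirely onto the last model (10.0)
    print(fixed)  # 5.5, the true mean of all ten models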

I believe the correct implementation would be:

        scale_new_data = 1.0 / (epoch + 1)
        scale_moving_average = (1.0 - scale_new_data)
        tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
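For what it's worth, with scale_new_data = 1/(epoch + 1) this is algebraically the same streaming-mean update as w_avg = (w_avg * epoch + w_new) / (epoch + 1); the key property is that the two coefficients sum to 1, so no single model can dominate the average.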

This is derived from the SWA authors' repo, with the relevant portions here and here. I'd be happy to submit this fix as a pull request, but wanted to raise the issue here first in case you'd like to handle it differently.

Reproducing the bug should be straightforward. I encountered it while experimenting with running swa.py on wildly different models (as a kind of cheap transfer learning or regularization). For example, I was averaging together gwern's anime model with ak9250's fine art portrait model, among several others, and noticed that the resulting network_avg.pkl always produced samples matching whichever model was last among the input pkl files. That led me to inspect the code more closely and find the original SWA code. With the changes above, the output network_avg.pkl works as expected, producing an average across the input models that looks close to what transfer learning would yield after a few ticks. And as you might expect, applying SWA to anime and fine art portraits creates some nightmare material, painterly people with weird cartoony anime eyes :)

Also, I'd like to say thanks for the excellent implementation and the updates here. It's really clean and has been quite nice to work with.

pbaylies commented 5 years ago

@aydao I noticed that too; the bug was preserved from the implementation I found, I think. It made a certain amount of sense to me to weight more recent epochs more heavily, though perhaps not that much more, so my workaround has been to just average the past 5 or 6 epochs or so. Feel free to submit this as a pull request, I'd be happy to merge it; and thanks for the kind words! (Also, cool use of network weight averaging, it always surprises me when these things work!)
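If it's useful to anyone reading along, a plain equal-weight average over just the last few checkpoints looks roughly like this (a hypothetical sketch, not the repo's code; window_average and the dict-of-arrays checkpoint format are made up for illustration):

    import numpy as np

    def window_average(checkpoints, k=5):
        # Equal-weight average of the last k checkpoints, where each checkpoint
        # is a dict mapping variable names to numpy arrays (a stand-in for
        # weights loaded from the last few pkls).
        recent = checkpoints[-k:]
        return {name: np.mean([ckpt[name] for ckpt in recent], axis=0)
                for name in recent[0]}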