Closed: aydao closed this issue 5 years ago
@aydao I noticed that too. I think the bug was preserved from the implementation I found. It made a certain amount of sense to me to weight more recent epochs more, though perhaps not that much more, so my workaround has been to average only the past 5 or 6 epochs or so. Feel free to submit this as a pull request, I'd be happy to merge it. And thanks for the kind words! (Also, cool use of network weight averaging; it always surprises me when these things work!)
I believe there is a bug in the implementation of stochastic weight averaging. Specifically, inside the `apply_swa` function for the network code, the scaling appears incorrect because the new model weights are scaled up by the epoch.
The result is that, regardless of which models swa.py reads in, the last pkl it reads is scaled so heavily that it overwrites nearly all of the weights in the current model. For example, the tenth model will be scaled massively (i.e., epoch=10) relative to the ones that come before it.
I believe the correct implementation would be:
This is derived from the swa authors' repo, with the relevant portions here and here. I'd be happy to submit this fix in a pull request, and wanted to raise the issue here first in case you'd like to handle it differently.
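To illustrate the idea, here is a minimal sketch of the running average that the SWA authors' code performs: each new model is folded in with weight 1/(n+1), where n is the number of models already averaged, so every model ends up with equal weight. This is not the code from this repository or the exact fix proposed here; the function names (`running_average`, `average_models`) and the dict-of-lists stand-in for network weights are hypothetical, chosen only to keep the example self-contained.

```python
def running_average(avg, new, n_seen):
    """Fold `new` into `avg`, where `avg` is already the mean of `n_seen` models."""
    # Weight of the incoming model shrinks as more models are averaged.
    alpha = 1.0 / (n_seen + 1)
    return {
        name: [(1.0 - alpha) * a + alpha * b for a, b in zip(avg[name], new[name])]
        for name in avg
    }


def average_models(models):
    """Equal-weight average of a list of models (hypothetical dict-of-lists format)."""
    # Start from a copy of the first model so the inputs are not modified.
    avg = {name: list(values) for name, values in models[0].items()}
    for n_seen, model in enumerate(models[1:], start=1):
        avg = running_average(avg, model, n_seen)
    return avg


# Toy stand-ins for three network snapshots with a single weight tensor "w".
models = [{"w": [0.0, 3.0]}, {"w": [3.0, 3.0]}, {"w": [6.0, 0.0]}]
result = average_models(models)
print(result)
```

With this update rule the order of the input models should not matter (up to floating-point rounding), which is the behavior one would expect from an average, and the last model read has no special influence.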
Reproducing the bug should be straightforward. I encountered this problem while experimenting by running `swa.py` on wildly different models (as a kind of cheap form of transfer or regularization). For example, I was averaging together gwern's anime model with ak9250's fine art portrait model, among several others, and noticed that the network_avg.pkl always produced samples matching whichever model was last among the input pkl files it read. That led me to inspect the code more closely and to find the original SWA code. With the changes above, the output network_avg.pkl now works as expected, producing an average across the input models that appears close to what transfer learning would yield after a few ticks. And as you might expect, applying SWA to anime and fine art portraits creates some nightmare material: painterly people with weird cartoony anime eyes :)
Also, I'd like to say thanks for the excellent implementation and updates here. It's really clean and has been quite nice to work with.