ajbrock / BigGAN-PyTorch

The author's officially unofficial PyTorch BigGAN implementation.
MIT License

Critical: Code reports training FID. #77

Closed. AlexanderMath closed this issue 3 years ago.

AlexanderMath commented 3 years ago

Thanks for an amazing job, I rarely find open-source code of such high quality.

It seems to me the Inception activation moments are precomputed on the training data.

Question 1. Is it correct that the moments are computed on training data?

Question 2. Is this also the case for the TF code used for the paper, or is this specific to the PyTorch code?

Question 3. Is there any reason why you would prefer training FID instead of validation FID?

I apologize if I missed something.

ajbrock commented 3 years ago
  1. Yep!
  2. Yep!
  3. FID is not a measure of generalization like a likelihood would be; it's a measure of similarity between two datasets. While some people prefer to use FID relative to the validation set, this doesn't really provide any meaningful "defense against overfitting" like a likelihood or test accuracy would. I'd expect FID against the val set to just be slightly lower but to not result in any meaningful change in trends for model or sample comparison.
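
For reference, the standard definition of FID, which is not spelled out in this thread: fit a Gaussian to the Inception activations of each of the two image sets, giving means $\mu_1, \mu_2$ and covariances $\Sigma_1, \Sigma_2$ (the symbols are notation chosen here), and compute the Fréchet distance between the two Gaussians:

```latex
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2} \right)
```

The quantity depends only on these first and second moments, which is why the moments of the reference set can be precomputed once.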

AlexanderMath commented 3 years ago

> FID is not a measure of generalization like a likelihood would be ...

Newer models seem to attain fid(model, train) < fid(train, val), i.e., model and training samples are closer to each other than training and validation samples. I fail to see why this is not indicative of overfitting.

As scientists, we hypothesize that our model generalizes, then attempt to falsify the generalization hypothesis using a test set. If FID allows us to falsify the generalization hypothesis, I fail to see why it isn't a meaningful measure of overfitting.

ajbrock commented 3 years ago

The issue is that FID is only comparing a set of samples generated by your model to some other set of samples, which simply doesn't say anything about generalization. A reasonably appropriate test (and a topic which much research has pursued but without a huge amount of success) would be to invert the model and try to find the nearest point z which maximizes p(x|z), then evaluate the likelihood p(z). This has its own set of flaws (typically you can't invert the model analytically, so you're minimizing a reconstruction loss with SGD, which is of course quite fraught), but it will at least tell you "can my model 'explain away' this new sample to some degree?"
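
The inversion test described above can be sketched on a toy problem. This is only an illustration under loud assumptions: the "generator" is a hypothetical linear map G(z) = W z, the prior p(z) is a standard normal, and plain gradient descent stands in for SGD. Minimizing the squared reconstruction error is what maximizing an isotropic Gaussian p(x|z) amounts to. The names `invert_and_score`, `W`, and `z_true` are invented for this sketch.

```python
import numpy as np

def invert_and_score(x, W, steps=500):
    """Toy inversion for a linear generator G(z) = W @ z.

    Finds z minimizing ||W z - x||^2 by gradient descent, then returns
    (z, log p(z)) with p(z) a standard normal prior.
    """
    # Step size 1/L, where L = 2 * sigma_max(W)^2 bounds the gradient's
    # Lipschitz constant, so the iteration cannot diverge.
    lr = 1.0 / (2.0 * np.linalg.norm(W, 2) ** 2)
    z = np.zeros(W.shape[1])
    for _ in range(steps):
        grad = 2.0 * W.T @ (W @ z - x)  # gradient of the reconstruction loss
        z = z - lr * grad
    d = z.size
    log_pz = -0.5 * (z @ z) - 0.5 * d * np.log(2.0 * np.pi)
    return z, log_pz

# Hypothetical 3-D "data" point produced from a known 2-D latent.
W = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
z_true = np.array([0.5, -1.0])
z_hat, log_pz = invert_and_score(W @ z_true, W)
print(z_hat, log_pz)
```

For a real GAN generator the loss is non-convex, so the recovered z depends on initialization and optimizer settings, which is the "quite fraught" part.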

Note also that, in general, just because a model has something akin to "lower training loss than validation loss" does not mean that the model has overfit. It's a common fallacy that the generalization gap is always indicative of overfitting, and I think that's going to be especially misleading when looking at a measure like FID, which compares bulk statistics in a very particular feature space.

One interesting experiment might be to measure the triplets of FID(model, train), FID(model, val), FID(train, val) for a flow and compare the train and test likelihood for this model as you intentionally overfit it (remove regularization, etc). While you'll likely see that FID(model, train) < FID(model, val), whether or not this will be consistently correlated with L_train vs L_test is a whole nother question :)
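
A minimal sketch of how such a triplet could be computed from three arrays of feature activations. It is not the repository's code: the function names are hypothetical, the inputs are random stand-ins rather than Inception activations, and the trace of the matrix square root is obtained from the eigenvalues of the covariance product instead of an explicit matrix square root routine.

```python
import numpy as np

def moments(acts):
    """Mean and covariance of an (n_samples, n_features) activation array."""
    acts = np.asarray(acts, dtype=np.float64)
    return acts.mean(axis=0), np.cov(acts, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((S1 S2)^(1/2)) equals the sum of square roots of the eigenvalues of
    # S1 @ S2, which are real and non-negative for covariance matrices;
    # clip tiny negative values caused by round-off.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * tr_covmean)

def fid_triplet(model_acts, train_acts, val_acts):
    """Return the three pairwise distances discussed above."""
    m, t, v = moments(model_acts), moments(train_acts), moments(val_acts)
    return {
        "model_train": frechet_distance(*m, *t),
        "model_val": frechet_distance(*m, *v),
        "train_val": frechet_distance(*t, *v),
    }

# Random stand-ins for activations (purely illustrative).
rng = np.random.default_rng(0)
train = rng.normal(size=(2000, 8))
val = rng.normal(size=(2000, 8))
model = rng.normal(loc=0.1, size=(2000, 8))
print(fid_triplet(model, train, val))
```

Because the estimate is computed from finite samples, FID(train, val) is nonzero even when both sets come from the same distribution, so the three numbers are best compared at matched sample sizes.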

I'm not saying these models aren't overfitting, for the record; GANs in particular operate pretty much exclusively in the "weaponized overfitting" regime. FID, however, is not really the metric you should be using to try to measure this, and the difference between reporting FID(model, train) or FID(model, val) is going to be a placebo that doesn't actually provide the insights one might hope it would.

AlexanderMath commented 3 years ago

Thanks for taking the time to answer.

> The issue is that FID is only comparing a set of samples generated by your model to some other set of samples, which, simply doesn't say anything about generalization.

I think we mean different things by generalization. By generalization error, I mean some notion of difference between the model distribution and the data distribution. This notion could be anything from Wasserstein to f-divergences like KL or JS. But Wasserstein [1] and all f-divergences [2] can be approximated by comparing a set of samples generated by the model to some other set of samples.

In other words, sampling can be used to approximate generalization error. I therefore don't see why a reliance on sampling implies that FID can't say anything about generalization.
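
As a small illustration of estimating a distance between distributions purely from samples, here is a sketch for the one-dimensional Wasserstein-1 distance. It is neither of the methods in [1] or [2]: in 1-D, with equal sample sizes, the empirical distance has a closed form (the mean absolute difference of the sorted samples), so no learned critic is needed. The function name is hypothetical.

```python
import numpy as np

def wasserstein1_1d(x, y):
    """Empirical Wasserstein-1 distance between two equal-size 1-D samples."""
    x = np.sort(np.asarray(x, dtype=np.float64))
    y = np.sort(np.asarray(y, dtype=np.float64))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    # Optimal 1-D transport matches the i-th smallest x to the i-th smallest y.
    return float(np.mean(np.abs(x - y)))

# Stand-ins for "model" and "held-out data" samples (illustrative only).
rng = np.random.default_rng(0)
model_samples = rng.normal(loc=0.0, size=5000)
data_samples = rng.normal(loc=0.5, size=5000)
print(wasserstein1_1d(model_samples, data_samples))
```

In high dimensions no such closed form is available, which is why approaches like [1] and [2] fall back on variational estimates.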

Question 1. Do you mean something different by generalization?

I wrote comments and questions on your entire message, but I think it's best to post them once we have settled the above.

[1] Use Kantorovich-Rubinstein duality, as done in WGAN: https://arxiv.org/abs/1701.07875
[2] Use variational divergence estimation, as done in f-GAN: https://arxiv.org/pdf/1606.00709