pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Calculation of FID metric made use of classification probabilities instead of feature vectors #3221

Open JonathanFoo0523 opened 3 months ago

JonathanFoo0523 commented 3 months ago

Based on my understanding, the FID calculation should use the 2048-dimensional feature vectors from the final average-pooling layer of InceptionNet.
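For reference, the standard FID definition compares the sample mean and covariance of these feature vectors for real (r) and generated (g) images. With 2048-dimensional pooling features, each mean is a 2048-vector and each covariance is a 2048 × 2048 matrix:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

The formula itself does not fix which features are used. Feeding it 1000-dimensional class probabilities still gives a valid distance, just a different one from the usual reference numbers.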

However, the FID() metric in fid.py is initialised as below:

if num_features is None and feature_extractor is None:
    num_features = 1000
    feature_extractor = InceptionModel(return_features=False, device=device)

By specification, InceptionModel returns prediction probabilities (dim=1000) when return_features is set to False. update() then uses this feature_extractor, so it obtains the probabilities instead of the feature vectors.
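To make the difference between the two modes concrete, here is a toy NumPy sketch. Random arrays stand in for Inception-v3's pooled features and its final classifier, and `head` is a hypothetical helper, not ignite's API. It assumes, as described above, that return_features=False yields softmax probabilities over 1000 classes while return_features=True yields the 2048-dimensional pooled features.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in (NOT the real Inception-v3): pooled features for 4 images.
pooled = rng.standard_normal((4, 2048))

# Toy stand-in for the final fully connected classifier (2048 -> 1000 classes).
fc_weight = rng.standard_normal((2048, 1000)) * 0.01
fc_bias = np.zeros(1000)


def head(features: np.ndarray, return_features: bool) -> np.ndarray:
    """Hypothetical helper mimicking the two output modes discussed in this issue."""
    if return_features:
        # Pooled features are passed through unchanged.
        return features
    logits = features @ fc_weight + fc_bias
    # Numerically stable softmax over the class dimension.
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


feats = head(pooled, return_features=True)
probs = head(pooled, return_features=False)
print(feats.shape)                          # (4, 2048)
print(probs.shape)                          # (4, 1000)
print(np.allclose(probs.sum(axis=1), 1.0))  # True
```

An FID computed on `probs` compares distributions over class scores, while one computed on `feats` compares distributions over the richer pooled representation, so the two settings produce different numbers for the same images.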

I think feature_extractor should be InceptionModel(return_features=True, device=device) instead?

vfdev-5 commented 3 months ago

@JonathanFoo0523 thanks for the question. Yes, the FID docstring is incorrect where it says:

num_features (Optional[int]) – number of features predicted by the model or the reduced feature vector of the image. Default value is 2048.

The actual default value is 1000.

I agree that the default FID setup may not match some other implementations. The FID score is a distance between feature distributions, and those features can be extracted from various models and from any internal layer.
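Because the distance only sees the feature matrices, the choice of extractor changes the score but not the computation. Below is a minimal NumPy sketch that computes FID from precomputed feature arrays. `fid_from_features` and `_sqrtm_psd` are hypothetical names, not ignite's API. The matrix square root is taken via a symmetric eigendecomposition, a swapped-in technique; many implementations use a general routine such as `scipy.linalg.sqrtm` instead, so results may differ slightly in the last digits.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Principal square root of a symmetric positive semi-definite matrix."""
    mat = (mat + mat.T) / 2.0  # remove tiny asymmetry caused by round-off
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # clip tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def fid_from_features(real: np.ndarray, fake: np.ndarray) -> float:
    """FID between two feature matrices of shape (n_samples, n_features).

    The extractor (pooled features, class probabilities, any layer) is chosen
    upstream; this function only sees the resulting vectors.
    """
    mu_r, mu_f = real.mean(axis=0), fake.mean(axis=0)
    sigma_r = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(fake, rowvar=False))
    # Tr((S_r S_f)^{1/2}) equals Tr((S_r^{1/2} S_f S_r^{1/2})^{1/2}),
    # and the inner matrix is symmetric PSD, so eigh can be used.
    sqrt_r = _sqrtm_psd(sigma_r)
    trace_covmean = np.trace(_sqrtm_psd(sqrt_r @ sigma_f @ sqrt_r))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_f) - 2.0 * trace_covmean)


rng = np.random.default_rng(0)
real = rng.standard_normal((500, 8))
print(fid_from_features(real, real))        # close to 0.0
print(fid_from_features(real, real + 0.5))  # about 2.0: same covariance, mean shifted by 0.5 in 8 dims
```

The shifted copy is a handy sanity check: its covariance matches the original, so the trace terms cancel and only the squared mean shift (8 × 0.5² = 2.0) remains.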

We also provide a note in the docs (https://pytorch.org/ignite/generated/ignite.metrics.FID.html) on how to match the results of pytorch_fid.

I think feature_extractor should be InceptionModel(return_features=True, device=device) instead?

So you are suggesting the following, which also makes sense:

if num_features is None and feature_extractor is None:
    num_features = 2048
    feature_extractor = InceptionModel(return_features=True, device=device)

I'm not against this change (plus all the warnings that a breaking change requires), but I'd like to better understand the motivation. Are you comparing ignite's FID results against another implementation, or do you see that the configuration with 2048-dimensional Inception features works better than the current default?