toshas / torch-fidelity

High-fidelity performance metrics for generative models in PyTorch

FID fails when `feature_layer_fid!=2048` #15

Closed: SkafteNicki closed this issue 3 years ago

SkafteNicki commented 3 years ago

Script to reproduce

```python
import torch
from torch.utils.data import Dataset
from torch_fidelity import calculate_metrics

img1 = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8)
img2 = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8)

class _ImgDataset(Dataset):
    def __init__(self, imgs):
        self.imgs = imgs

    def __getitem__(self, idx):
        return self.imgs[idx]

    def __len__(self):
        return self.imgs.shape[0]

def main(feature_size):
    torch_fid = calculate_metrics(
        _ImgDataset(img1),
        _ImgDataset(img2),
        fid=True,
        feature_layer_fid=feature_size
    )
    print(torch_fid)

if __name__ == "__main__":
    main("64")  # fails
    main("192")  # fails
    main("768")  # fails
    main("2048")  # succeeds
```

This fails with an assertion error. It seems that when `feature_layer_fid != 2048`, the calculated features are not 2D but instead have shape `[N, feature_size, 1, 1]`. It should be simple to fix by introducing a `feature.squeeze()` somewhere appropriate.
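A minimal NumPy sketch of the shape problem described above and a flattening fix, assuming the FID statistics are computed from an `[N, D]` feature matrix as a per-feature mean plus `np.cov(..., rowvar=False)`. The helpers `features_to_statistics` and `flatten_features` are hypothetical, written for illustration only (they are not torch-fidelity's API), and the random array merely stands in for real Inception features. It also suggests why reshaping to `[N, -1]` (or `torch.flatten(x, start_dim=1)` in PyTorch) may be safer than a bare `squeeze()`, which also drops the batch axis when N is 1.

```python
import numpy as np


def features_to_statistics(features: np.ndarray):
    """Hypothetical helper: FID mean/covariance from an [N, D] feature matrix."""
    # FID statistics assume one row per image and one column per feature.
    assert features.ndim == 2, f"expected [N, D] features, got {features.shape}"
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


def flatten_features(features: np.ndarray) -> np.ndarray:
    """Hypothetical helper: collapse trailing dims, e.g. [N, D, 1, 1] -> [N, D]."""
    # reshape keeps the batch axis even when N == 1, unlike a bare squeeze().
    return features.reshape(features.shape[0], -1)


# Simulated pooled output of a 192-channel layer: shape [N, 192, 1, 1].
rng = np.random.default_rng(0)
raw = rng.standard_normal((5, 192, 1, 1))

try:
    features_to_statistics(raw)  # 4D input trips the 2D assertion
except AssertionError as err:
    print("fails:", err)

mu, sigma = features_to_statistics(flatten_features(raw))
print(mu.shape, sigma.shape)  # (192,) (192, 192)

# Pitfall: squeeze() on a single-image batch removes the batch axis as well.
print(raw[:1].squeeze().shape)  # (192,)
print(flatten_features(raw[:1]).shape)  # (1, 192)
```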

toshas commented 3 years ago

Should be fixed immediately in the wip branch; thank you for spotting the issue!