w86763777 / pytorch-image-generation-metrics

PyTorch implementation of common image generation metrics.
Apache License 2.0
ddpm diffusion-models fid frechet-inception-distance gan generative-adversarial-network inception-score pytorch

PyTorch Implementation of Common Image Generation Metrics



pip install pytorch-image-generation-metrics

Quick Start

from pytorch_image_generation_metrics import get_inception_score, get_fid

images = ... # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images)        # Inception Score
FID = get_fid(images, 'path/to/fid_ref.npz') # Frechet Inception Distance

The file path/to/fid_ref.npz is compatible with the official FID implementation.


The FID implementation is inspired by pytorch-fid.
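FID fits a Gaussian to the Inception features of each image set and reports the Frechet distance between the two Gaussians: d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)). The following is a minimal NumPy/SciPy sketch of that closed form, written in the style of pytorch-fid (which takes the matrix square root with scipy.linalg.sqrtm). It is an illustration rather than this package's code, and the function name `frechet_distance` is hypothetical.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: retry with a small diagonal offset.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a tiny imaginary component; keep the real part.
    covmean = np.real(covmean)

    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))
```

If the reference file stores its statistics under the keys mu and sigma (the layout used by the official TTUR and pytorch-fid statistics files), the generated set's statistics would be compared against np.load('path/to/fid_ref.npz')['mu'] and ['sigma'].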

This repository was developed for personal research. If you find this package useful, feel free to open issues.


Reproducing Results of Official Implementations on CIFAR-10

                       Train IS      Test IS       FID (Train 50k vs. Test 10k)
Official               11.24±0.20    10.98±0.22    3.1508
ours                   11.26±0.13    10.97±0.19    3.1525
ours (use_torch=True)  11.26±0.15    10.97±0.20    3.1457

The results differ slightly from those of the official implementations because of numerical differences between the PyTorch and TensorFlow frameworks.
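In the original TensorFlow code, the Inception Score is exp(E_x[KL(p(y|x) || p(y))]), computed separately on 10 disjoint splits of the samples, and the ± term is the standard deviation across those splits. The NumPy sketch below reproduces that definition from a precomputed [N, num_classes] matrix of softmax probabilities. It assumes the Inception forward pass has already been run; the function name `inception_score` is hypothetical and not this package's API.

```python
from typing import Tuple

import numpy as np


def inception_score(probs: np.ndarray, splits: int = 10) -> Tuple[float, float]:
    """Mean and std of the Inception Score over `splits` disjoint chunks.

    probs: [N, num_classes] array of p(y|x); each row sums to 1.
    """
    eps = 1e-12  # guards against log(0) for zero probabilities
    scores = []
    for part in np.array_split(probs, splits):
        # Marginal label distribution p(y), estimated on this chunk only.
        p_y = part.mean(axis=0, keepdims=True)
        # Per-sample KL(p(y|x) || p(y)).
        kl = (part * (np.log(part + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

Two sanity checks follow from the definition: identical predictions for every image give a score of 1, and confident predictions spread evenly over k classes give a score of k.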


Prepare Statistical Reference for FID
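FID reference statistics are the mean vector and covariance matrix of Inception features (2048-dimensional pool features for the standard FID network) over a reference image set. The official TTUR and pytorch-fid statistics files store them under the keys mu and sigma. Below is a minimal NumPy sketch, assuming the features have already been extracted into an [N, D] array; `save_fid_reference` is a hypothetical helper, not part of this package.

```python
import numpy as np


def save_fid_reference(features: np.ndarray, path: str) -> None:
    """Write mean ('mu') and covariance ('sigma') of [N, D] features to an .npz file."""
    features = np.asarray(features, dtype=np.float64)
    mu = features.mean(axis=0)
    # rowvar=False: rows are samples, columns are feature dimensions.
    sigma = np.cov(features, rowvar=False)
    np.savez(path, mu=mu, sigma=sigma)
```

The covariance estimate needs many more samples than feature dimensions to be stable, which is one reason reference statistics are usually computed from tens of thousands of images.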

Inception Features

Using torch.Tensor as images
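The Quick Start expects a float tensor of shape [N, 3, H, W] with values in [0, 1]. Decoded image files are usually uint8 arrays in [N, H, W, 3] layout, and many generators end in tanh and produce values in [-1, 1], so both need converting first. A small sketch; the helper names `uint8_nhwc_to_unit` and `tanh_to_unit` are hypothetical.

```python
import numpy as np
import torch


def uint8_nhwc_to_unit(images: np.ndarray) -> torch.Tensor:
    """Convert uint8 images [N, H, W, 3] in 0..255 to float32 [N, 3, H, W] in [0, 1]."""
    tensor = torch.from_numpy(np.ascontiguousarray(images))  # [N, H, W, 3], uint8
    tensor = tensor.permute(0, 3, 1, 2)                      # channels first
    return tensor.float().div(255.0).contiguous()            # rescale to [0, 1]


def tanh_to_unit(images: torch.Tensor) -> torch.Tensor:
    """Map generator outputs in [-1, 1] (e.g. from a tanh layer) to [0, 1]."""
    return (images.clamp(-1.0, 1.0) + 1.0) / 2.0
```

The result can then be passed directly as images, e.g. get_inception_score(uint8_nhwc_to_unit(array)).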

Using PyTorch DataLoader to Provide Images

  1. Use pytorch_image_generation_metrics.ImageDataset to collect images from disk, or use your own custom torch.utils.data.Dataset.

    from pytorch_image_generation_metrics import ImageDataset
    from torch.utils.data import DataLoader
    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)

    You can wrap a generative model in a dataset to support generating images on the fly.

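    import torch
    from torch.utils.data import Dataset, DataLoader

    class GeneratorDataset(Dataset):
        def __init__(self, G, noise_dim):
            self.G = G
            self.noise_dim = noise_dim

        def __len__(self):
            return 50000  # number of images to generate

        def __getitem__(self, index):
            # G is assumed to output [1, 3, H, W] in [0, 1]; [0] drops the
            # batch dimension so DataLoader yields [batch_size, 3, H, W].
            with torch.no_grad():
                return self.G(torch.randn(1, self.noise_dim))[0]

    dataset = GeneratorDataset(G, noise_dim=128)
    loader = DataLoader(dataset, batch_size=50, num_workers=0)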
  2. Calculate metrics

    from pytorch_image_generation_metrics import (
        get_inception_score,
        get_fid,
        get_inception_score_and_fid,
    )
    # Inception Score
    IS, IS_std = get_inception_score(loader)
    # Frechet Inception Distance
    FID = get_fid(loader, 'path/to/fid_ref.npz')
    # Inception Score & Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/fid_ref.npz')
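Because a full evaluation can take a long time, it may be worth confirming up front that the loader yields what the Quick Start passes as images: float batches of shape [B, 3, H, W] with values in [0, 1]. The hypothetical helper below (not part of this package) inspects the first few batches. It assumes each batch is a single tensor, as with the GeneratorDataset example; loaders that yield (image, label) tuples would need unpacking first.

```python
import torch


def check_loader(loader, max_batches: int = 1) -> None:
    """Raise if the first `max_batches` batches are not float [B, 3, H, W] in [0, 1].

    Assumes every batch is a single tensor (not an (image, label) tuple).
    """
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        if not isinstance(batch, torch.Tensor):
            raise TypeError(f"batch {i}: expected a tensor, got {type(batch).__name__}")
        if batch.dim() != 4 or batch.shape[1] != 3:
            raise ValueError(f"batch {i}: expected shape [B, 3, H, W], got {tuple(batch.shape)}")
        if not batch.is_floating_point():
            raise TypeError(f"batch {i}: expected a floating-point tensor, got {batch.dtype}")
        lo, hi = batch.min().item(), batch.max().item()
        if lo < 0.0 or hi > 1.0:
            raise ValueError(f"batch {i}: values must lie in [0, 1], got [{lo:.3f}, {hi:.3f}]")
```

Calling check_loader(loader) before get_fid(loader, ...) catches a generator that still outputs [-1, 1] or an extra batch dimension without waiting for all 50k images.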

Load Images from a Directory
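ImageDataset takes a directory and a list of extensions (exts=['png', 'jpg']). To illustrate that kind of file discovery only, here is a standard-library sketch. The helper `list_images` is hypothetical, and its choices to search subdirectories recursively and to match extensions case-insensitively are assumptions of this sketch, not documented behavior of ImageDataset.

```python
from pathlib import Path
from typing import Iterable, List


def list_images(root: str, exts: Iterable[str] = ("png", "jpg")) -> List[Path]:
    """Return files under `root` (recursively) whose extension is in `exts`, sorted."""
    wanted = {e.lower().lstrip(".") for e in exts}
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower().lstrip(".") in wanted
    )
```

Sorting the paths keeps the image order deterministic across runs, which matters if only a prefix of the directory is evaluated.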

Accelerating Matrix Computation with PyTorch
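The results table lists a use_torch=True variant, which appears to compute the FID matrix algebra in PyTorch instead of NumPy/SciPy. The package's own method is not reproduced here. As one possible approach, the sketch below avoids forming the matrix square root explicitly, using Tr((S1 S2)^(1/2)) = sum_i sqrt(lambda_i), where lambda_i are the eigenvalues of S1 S2. These eigenvalues are real and non-negative for positive semi-definite covariances, because S1 S2 is similar to S1^(1/2) S2 S1^(1/2). The function name is hypothetical, and float64 inputs are assumed for accuracy; torch.linalg.eigvals accepts CPU or CUDA tensors.

```python
import torch


def frechet_distance_torch(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians given as torch tensors (float64 assumed).

    Tr((sigma1 @ sigma2)^(1/2)) is computed as the sum of square roots of the
    eigenvalues of sigma1 @ sigma2 instead of via an explicit matrix square root.
    """
    diff = mu1 - mu2
    eigvals = torch.linalg.eigvals(sigma1 @ sigma2)  # complex dtype
    # Tiny negative or complex eigenvalues from round-off contribute ~0 real part.
    tr_covmean = eigvals.sqrt().real.sum()
    return diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean
```

The eigenvalue route needs only one non-symmetric eigendecomposition, whereas scipy.linalg.sqrtm runs on the CPU and returns the full square-root matrix even though FID only needs its trace.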

Tested Versions


License

This implementation is licensed under the Apache License 2.0.

This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.