pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

FID results do not match pytorch-fid / torch-fidelity #192

Open SingleZombie opened 6 months ago

SingleZombie commented 6 months ago

🐛 Describe the bug

Run FrechetInceptionDistance on any pair of image sets:

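import torch
from torcheval.metrics import FrechetInceptionDistance

device = "cuda" if torch.cuda.is_available() else "cpu"

# imgs_1: real images, imgs_2: generated images.
# Assumed input format: float tensors of shape (N, 3, H, W) with values in [0, 1].
imgs_1 = ...
imgs_2 = ...

fid = FrechetInceptionDistance(device=device)
fid.update(imgs_1, True)   # is_real=True
fid.update(imgs_2, False)  # is_real=False
print(fid.compute())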

I have tracked down two causes.

  1. The InceptionV3 architecture and weights are not exactly the same as the TensorFlow version, which is what most papers use. I replaced the model with the InceptionV3 from pytorch-fid.
  2. Insufficient numerical precision in torcheval.metrics.image.fid.FrechetInceptionDistance._calculate_frechet_distance produces wrong results. I fixed this by casting all the accumulated variables (fake_sum, real_cov_sum, ...) to float64.
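Item 2 can be illustrated with a minimal NumPy sketch. It assumes features have already been extracted, and that torcheval accumulates running sums (like the `fake_sum` and `real_cov_sum` mentioned above) and derives the covariance with the one-pass formula (Σx xᵀ − n μ μᵀ) / (n − 1). That formula subtracts two large, nearly equal terms, which is where float32 loses precision. `RunningGaussianStats` and `frechet_distance` are hypothetical names, not torcheval API. The distance is d² = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2 (Σ₁ Σ₂)^{1/2}), with the trace of the matrix square root taken as the sum of square roots of the eigenvalues of Σ₁ Σ₂.

```python
import numpy as np


class RunningGaussianStats:
    """Hypothetical helper: accumulates sum, sum of outer products and count
    in float64, then fits a Gaussian (mean, covariance) to the features."""

    def __init__(self, dim: int):
        self.sum = np.zeros(dim, dtype=np.float64)
        self.outer_sum = np.zeros((dim, dim), dtype=np.float64)
        self.n = 0

    def update(self, feats: np.ndarray) -> None:
        # feats: (N, dim). Upcast *before* accumulating so no float32
        # rounding error is baked into the running sums.
        f = np.asarray(feats, dtype=np.float64)
        self.sum += f.sum(axis=0)
        self.outer_sum += f.T @ f
        self.n += f.shape[0]

    def mean_cov(self):
        mu = self.sum / self.n
        # One-pass covariance: large, nearly equal terms are subtracted here,
        # which is why the accumulators need float64.
        cov = (self.outer_sum - self.n * np.outer(mu, mu)) / (self.n - 1)
        return mu, cov


def frechet_distance(mu1, cov1, mu2, cov2) -> float:
    diff = mu1 - mu2
    # For PSD cov1, cov2 the eigenvalues of cov1 @ cov2 are real and >= 0 in
    # exact arithmetic; cast to complex so tiny negative round-off values do
    # not produce NaN, then keep the real part of the square roots.
    eig = np.linalg.eigvals(cov1 @ cov2).astype(np.complex128)
    tr_sqrt = np.sqrt(eig).real.sum()
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_sqrt)
```

For example, two `RunningGaussianStats(2048)` objects can be fed real and generated features batch by batch, and `frechet_distance(*real.mean_cov(), *fake.mean_cov())` then gives the score.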

With these two modifications, the FID results match. I suggest the maintainers reimplement FID so that it agrees with the other libraries; otherwise researchers won't use torcheval to compute FID.
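For item 1, one option that avoids patching torcheval is to pass a different feature extractor through the `model` argument of `FrechetInceptionDistance` (together with `feature_dim`). pytorch-fid's `InceptionV3` returns a list of feature maps rather than a flat tensor, so a small adapter is needed. This is a minimal sketch, assuming torcheval calls `model(images)` and expects an output of shape (N, feature_dim); `FirstOutputFeatures` is a hypothetical name.

```python
import torch
from torch import nn


class FirstOutputFeatures(nn.Module):
    """Hypothetical adapter: takes the first feature map returned by a
    backbone that outputs a list (as pytorch-fid's InceptionV3 does) and
    flattens it to (N, feature_dim)."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        out = self.backbone(images)[0]   # e.g. (N, 2048, 1, 1) for the pool3 block
        return out.flatten(start_dim=1)  # (N, 2048)
```

Assumed usage: `FrechetInceptionDistance(model=FirstOutputFeatures(InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])), feature_dim=2048, device=device)`, with `InceptionV3` imported from `pytorch_fid.inception`. By default pytorch-fid's network resizes its input to 299x299 and rescales [0, 1] values itself, so the same images can be passed to `update`. This only swaps the feature extractor; according to the report, the float64 change from item 2 is still needed for the results to match.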

Versions

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] pytorch-fid==0.3.0
[pip3] pytorch-lightning==1.4.2
[pip3] torch==2.0.1
[pip3] torch-ema==0.3
[pip3] torch-fidelity==0.3.0
[pip3] torchaudio==2.0.2
[pip3] torcheval==0.0.7
[pip3] torchmetrics==0.5.0
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.24.4 pypi_0 pypi
[conda] pytorch-fid 0.3.0 pypi_0 pypi
[conda] pytorch-lightning 1.4.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 2.0.1 pypi_0 pypi
[conda] torch-ema 0.3 pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchaudio 2.0.2 pypi_0 pypi
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchmetrics 0.5.0 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi

AwalkZY commented 6 months ago

Same here. I strongly suggest that the maintainers align the calculation with pytorch-fid or torch-fidelity.