Motivation
I needed this metric in PyTorch but couldn't find an existing implementation that works on PyTorch tensors. I've used torchmetrics (including for FID) in the past, so I thought this metric could be a good addition to the library. Since FID is already present, I imagined an MiFID implementation should not be too difficult to write as an extension of the existing FID code.
Pitch
I would like to add MiFID to the library as an extension of FID. I've already implemented it as a Metric subclass and run some basic tests against the original implementation from the source repo (which uses NumPy for matrix operations, found here).
Although I've already got the metric working (and I'll paste my class code below), I would need some help integrating it into the library and adding tests in order to follow the contribution guidelines and the steps given at the bottom of the Metric page.
Alternatives
The current alternative is just using the NumPy-based solution given in the original source repository.
Additional context
Here is my current code for this feature:
from copy import deepcopy
from typing import Any, Union

import torch  # needed for torch.norm, torch.cov, torch.cat, etc. used below
from torch import Tensor
from torch.nn import Module

from torchmetrics.image.fid import NoTrainInceptionV3, _compute_fid
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE


def _normalize_rows(x: Tensor) -> Tensor:
    """Normalize each row of the matrix ``x`` to unit length.

    Args:
        x: A PyTorch tensor of shape (n, m)

    Returns:
        The row-normalized PyTorch tensor.
    """
    return x / torch.norm(x, dim=1, keepdim=True)


def _distance_thresholding(d: Tensor, eps: float = 0.1) -> Tensor:
    """Keep the distance if it is below ``eps``; otherwise return 1 (no penalty)."""
    # Return a tensor in both branches so the result type is consistent.
    return d if d < eps else torch.ones_like(d)


def _compute_cosine_distance(features1: Tensor, features2: Tensor) -> Tensor:
    """Mean over rows of ``features1`` of the minimum cosine distance to any row of ``features2``."""
    # Drop rows whose entries sum to zero before normalizing.
    features1_nozero = features1[torch.sum(features1, dim=1) != 0]
    features2_nozero = features2[torch.sum(features2, dim=1) != 0]
    norm_f1 = _normalize_rows(features1_nozero)
    norm_f2 = _normalize_rows(features2_nozero)
    # Pairwise distance matrix of shape (n1, n2).
    d = 1.0 - torch.abs(torch.matmul(norm_f1, norm_f2.t()))
    return torch.mean(torch.min(d, dim=1).values)


def _compute_mifid(
    mu1: Tensor, sigma1: Tensor, features1: Tensor, mu2: Tensor, sigma2: Tensor, features2: Tensor
) -> Tensor:
    """Divide the FID by the thresholded cosine distance between the two feature sets."""
    fid_value = _compute_fid(mu1, sigma1, mu2, sigma2)
    distance = _compute_cosine_distance(features1, features2)
    distance_thr = _distance_thresholding(distance, eps=0.1)
    # Small constant (10e-15 == 1e-14) guards against division by zero.
    return fid_value / (distance_thr + 10e-15)


class MemorizationInformedFrechetInceptionDistance(Metric):
    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False

    # Only the stacked features are registered as states;
    # means and covariances are derived from them in `compute`.
    real_features_stacked: Tensor
    fake_features_stacked: Tensor

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if isinstance(feature, int):
            num_features = feature
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "MemorizationInformedFrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )
            self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, Module):
            self.inception = feature
            # NOTE: this assumes the custom module exposes a `.device` attribute;
            # a plain `torch.nn.Module` does not define one.
            dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8, device=self.inception.device)
            num_features = self.inception(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize

        # All extracted features are kept, because the cosine distance needs every row.
        self.add_state("real_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")
        self.add_state("fake_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features."""
        imgs = (imgs * 255).byte() if self.normalize else imgs
        features = self.inception(imgs)
        self.orig_dtype = features.dtype
        features = features.double()

        if features.dim() == 1:
            features = features.unsqueeze(0)
        if real:
            self.real_features_stacked = torch.cat((self.real_features_stacked, features), dim=0)
        else:
            self.fake_features_stacked = torch.cat((self.fake_features_stacked, features), dim=0)

    def compute(self) -> Tensor:
        """Calculate MiFID score based on accumulated extracted features from the two distributions."""
        mean_real = torch.mean(self.real_features_stacked, dim=0)
        mean_fake = torch.mean(self.fake_features_stacked, dim=0)
        # torch.cov expects variables in rows, hence the transpose.
        cov_real = torch.cov(self.real_features_stacked.t())
        cov_fake = torch.cov(self.fake_features_stacked.t())
        return _compute_mifid(
            mean_real, cov_real, self.real_features_stacked,
            mean_fake, cov_fake, self.fake_features_stacked,
        ).to(self.orig_dtype)

    def to(self, device):
        # Move the feature extractor together with the metric states.
        self.inception = self.inception.to(device)
        return super().to(device)

    def reset(self) -> None:
        if not self.reset_real_features:
            # Preserve the real features across resets.
            real_features_stacked = deepcopy(self.real_features_stacked)
            super().reset()
            self.real_features_stacked = real_features_stacked
        else:
            super().reset()

A couple of notes about this:
I added device=self.inception.device to the dummy_image creation taken from the original FID, because that line raised errors when my Inception network was already on a CUDA device but the Metric wasn't yet (it is only being initialized at that point). This happened with the original FID class as well. I don't know whether this works across the broader torchmetrics library, but it was a workaround for me.
The MiFID metric uses the cosine distance between the two sets of feature vectors. I wasn't sure how to break this down so that the features don't all need to be stacked and kept until the distance is calculated in compute. Perhaps there's a way to avoid storing all the stacked features; I'm hoping someone here knows whether that can be done.
Additionally, I haven't tested this extensively yet, but I have found that it gives values very similar to those from the source NumPy solution. I implemented this for my own project and wanted to contribute it here :).
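A parity check of the kind described could look roughly like the following sketch. It compares only the cosine-distance part on random features. The NumPy function is a transcription of the PyTorch helper in this issue, written here for illustration (an assumption: it is not copied from the source repository), and both function names are hypothetical.

```python
import numpy as np
import torch
from torch import Tensor


def cosine_distance_numpy(f1: np.ndarray, f2: np.ndarray) -> float:
    # NumPy transcription of the helper in this issue (assumption, not the source repo's code).
    f1 = f1[f1.sum(axis=1) != 0]
    f2 = f2[f2.sum(axis=1) != 0]
    f1 = f1 / np.linalg.norm(f1, axis=1, keepdims=True)
    f2 = f2 / np.linalg.norm(f2, axis=1, keepdims=True)
    d = 1.0 - np.abs(f1 @ f2.T)
    return float(np.mean(np.min(d, axis=1)))


def cosine_distance_torch(f1: Tensor, f2: Tensor) -> Tensor:
    # Same steps with PyTorch tensors.
    f1 = f1[f1.sum(dim=1) != 0]
    f2 = f2[f2.sum(dim=1) != 0]
    f1 = f1 / torch.norm(f1, dim=1, keepdim=True)
    f2 = f2 / torch.norm(f2, dim=1, keepdim=True)
    d = 1.0 - torch.abs(f1 @ f2.t())
    return d.min(dim=1).values.mean()


rng = np.random.default_rng(0)
a = rng.normal(size=(32, 8))  # float64 features for both backends
b = rng.normal(size=(40, 8))

np_value = cosine_distance_numpy(a, b)
torch_value = cosine_distance_torch(torch.from_numpy(a), torch.from_numpy(b)).item()

# Both backends run in float64 here, so the values should agree closely.
assert abs(np_value - torch_value) < 1e-10
```

A check against the actual source implementation would replace cosine_distance_numpy with the function from that repository.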
🚀 Feature
The goal of this feature would be to implement the MiFID metric initially proposed in "On Training Sample Memorization: Lessons from Benchmarking Generative Modeling with a Large-scale Competition". It is an extension of FID, so it could be implemented similarly.
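Written out, the code in this issue computes the quantity below. The notation is introduced here for illustration and is not taken from the paper: $f_i$ and $g_j$ are the rows of features1 and features2 after dropping all-zero rows, $N$ is the number of remaining rows in features1, and $\epsilon = 0.1$ is the threshold used in the code.

```latex
d = \frac{1}{N} \sum_{i=1}^{N} \min_{j}
    \left( 1 - \frac{\lvert \langle f_i, g_j \rangle \rvert}{\lVert f_i \rVert \, \lVert g_j \rVert} \right),
\qquad
d_{\mathrm{thr}} =
\begin{cases}
  d, & d < \epsilon \\
  1, & \text{otherwise}
\end{cases},
\qquad
\mathrm{MiFID} = \frac{\mathrm{FID}}{d_{\mathrm{thr}} + 10^{-14}}
```

When the two feature sets are not unusually close ($d \ge \epsilon$), the denominator is essentially 1 and MiFID reduces to FID. A small $d$ inflates the score, which penalizes generated samples that sit very close to the other set.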