yuanzhi-zhu opened 8 months ago
```python
import argparse
import torch
import torchvision
import numpy as np
import random
import tqdm
from cleanfid import fid


def seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU.
    np.random.seed(seed)  # NumPy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False  # benchmark must be off for deterministic runs
    torch.backends.cudnn.deterministic = True
class ResizeDataset(torch.utils.data.Dataset):
    """
    A placeholder Dataset that enables parallelizing the resize operation
    using multiple CPU cores.

    files: list/array of images as np arrays with values in [0, 255]
    fn_resize: function that takes an np array as input and returns a resized np array
    """

    def __init__(self, files, mode, size=(299, 299), fdir=None):
        self.files = files
        self.fdir = fdir
        self.transforms = torchvision.transforms.ToTensor()
        self.size = size
        self.fn_resize = fid.build_resizer(mode)
        self.custom_image_transform = lambda x: x

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        img_np = self.files[i]
        # apply a custom image transform before resizing the image to 299x299
        img_np = self.custom_image_transform(img_np)
        # fn_resize expects an np array and returns an np array
        img_resized = self.fn_resize(img_np)
        # ToTensor() scales to [0, 1] only for uint8 inputs, so scale back by 255
        if img_resized.dtype == "uint8":
            img_t = self.transforms(np.array(img_resized)) * 255
        elif img_resized.dtype == "float32":
            img_t = self.transforms(img_resized)
        else:
            raise TypeError(f"unexpected dtype after resize: {img_resized.dtype}")
        return img_t
# https://github.com/openai/consistency_models_cifar10/blob/main/jcm/metrics.py#L117
def compute_fid(
    samples,
    feat_model,
    dataset_name="cifar10",
    ref_stat=None,
    dataset_res=32,
    dataset_split="train",
    batch_size=512,
    num_workers=12,
    mode="legacy_tensorflow",
    device=torch.device("cuda:0"),
    seed=0,
):
    dataset = ResizeDataset(samples, mode=mode)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )
    l_feats = []
    for batch in tqdm.tqdm(dataloader):
        l_feats.append(fid.get_batch_features(batch, feat_model, device))
    np_feats = np.concatenate(l_feats)
    mu = np.mean(np_feats, axis=0)
    sigma = np.cov(np_feats, rowvar=False)
    if ref_stat is not None:
        ref_mu, ref_sigma = ref_stat
    else:
        ref_mu, ref_sigma = fid.get_reference_statistics(
            dataset_name, dataset_res, mode=mode, seed=seed, split=dataset_split
        )
    score = fid.frechet_distance(mu, sigma, ref_mu, ref_sigma)
    return score
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default='cifar10')
    parser.add_argument("--data_path", type=str, default='')
    parser.add_argument("--target_path", type=str, default='')
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--num_channels", type=int, default=3)
    parser.add_argument("--num_samples", type=int, default=50000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    seed_everywhere(args.seed)
    assert args.data_path != '', "data_path must be specified."

    ### basic info
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}; CUDA version: {str(torch.version.cuda)}')

    ### build feature extractor
    mode = "legacy_tensorflow"
    feat_model = fid.build_feature_extractor(mode, device)

    # perturb the seed randomly (deterministic given the initial seeding above)
    args.seed += np.random.randint(0, 1000000)
    print(f'Using seed: {args.seed};')
    ### set random seed everywhere
    seed_everywhere(args.seed)

    ### load target samples and calculate reference statistics
    if args.target_path:
        print(f'load target samples from {args.target_path}')
        try:
            ## pre-computed statistics, e.g. from https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/
            ref = np.load(args.target_path)
            ref_mu, ref_sigma = ref['mu'], ref['sigma']
            ref_stat = (ref_mu, ref_sigma)
            print('reference statistics loaded!')
        except (KeyError, IndexError):
            # otherwise treat the file as raw target samples in [-1, 1], shape (B, C, H, W)
            target_samples = np.load(args.target_path)
            target_samples = torch.from_numpy(target_samples)
            target_samples = target_samples / 2 + 0.5
            target_samples = np.clip(target_samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8)
            target_samples = target_samples.reshape((-1, args.image_size, args.image_size, args.num_channels))
            target_dataset = ResizeDataset(target_samples, mode=mode)
            target_dataloader = torch.utils.data.DataLoader(
                target_dataset,
                batch_size=512,
                shuffle=False,
                drop_last=False,
                num_workers=0,
            )
            l_feats = []
            for batch in tqdm.tqdm(target_dataloader):
                l_feats.append(fid.get_batch_features(batch, feat_model, device))
            np_feats = np.concatenate(l_feats)
            ref_mu = np.mean(np_feats, axis=0)
            ref_sigma = np.cov(np_feats, rowvar=False)
            ref_stat = (ref_mu, ref_sigma)
            print('reference statistics calculated!')
    else:
        ref_stat = None

    ### calculate fid for given data
    print(f'calculate fid for data from {args.data_path}')
    samples = np.load(args.data_path)
    print(f'samples shape: {samples.shape}')
    print(f'samples range: {samples.min()}, {samples.max()}, should be ~ [-1, 1].')
    samples = torch.from_numpy(samples)
    samples = samples / 2 + 0.5
    samples = np.clip(samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8)
    all_samples = samples.reshape((-1, args.image_size, args.image_size, args.num_channels))
    all_samples = all_samples[: args.num_samples]

    fid_score = compute_fid(
        all_samples,
        mode=mode,
        dataset_name=args.dataset_name,
        device=device,
        feat_model=feat_model,
        seed=args.seed,
        num_workers=0,
        ref_stat=ref_stat,
    )
    print(f"data_path-{args.data_path} --- FID: {fid_score:0.6f}")
```
Thank you for your great work!! It's really helpful.
I wonder if we can calculate the FID between a NumPy file (.npy) that contains an array of shape (B, C, H, W) and pre-computed dataset statistics?
Massive thanks in advance.
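In case a sketch helps, the following reuses `compute_fid` and `ResizeDataset` from the script above to do exactly that, assuming the pre-computed statistics are an EDM-style `.npz` with `mu` and `sigma` arrays and the samples are a `(B, C, H, W)` float array in `[-1, 1]` (the file names are placeholders):

```python
import numpy as np
import torch
from cleanfid import fid

mode = "legacy_tensorflow"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feat_model = fid.build_feature_extractor(mode, device)

# pre-computed reference statistics (assumed keys: 'mu', 'sigma')
ref = np.load('cifar10-32x32.npz')
ref_stat = (ref['mu'], ref['sigma'])

# generated samples: (B, C, H, W) in [-1, 1] -> (B, H, W, C) uint8 in [0, 255]
samples = np.load('samples.npy')
samples = ((samples + 1) / 2).transpose(0, 2, 3, 1)
samples = np.clip(samples * 255., 0, 255).astype(np.uint8)

# compute_fid is the function defined in the script above
score = compute_fid(samples, feat_model, mode=mode, device=device,
                    num_workers=0, ref_stat=ref_stat)
print(f'FID: {score:0.6f}')
```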