GaParmar / clean-fid

PyTorch - FID calculation with proper image resizing and quantization steps [CVPR 2022]
https://www.cs.cmu.edu/~clean-fid/
MIT License
894 stars 68 forks source link

Images in a big numpy npy file rather than in a folder #52

Open yuanzhi-zhu opened 8 months ago

yuanzhi-zhu commented 8 months ago

Thank you for your great work!! It's really helpful.

Is it possible to compute the FID between images stored in a single NumPy file (.npy), containing an array of shape (B, C, H, W), and pre-computed dataset statistics?

Massive thanks in advance.

yuanzhi-zhu commented 1 month ago
import argparse
import torch
import torchvision
import numpy as np
import random
import tqdm
from cleanfid import fid

def seed_everywhere(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and, when CUDA is available, all CUDA devices), then configures
    cuDNN for deterministic execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)  # Python random module.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case too.
    # Benchmark mode lets cuDNN auto-tune and pick algorithms per run, which
    # works against reproducibility; it must be off when determinism is the
    # goal.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class ResizeDataset(torch.utils.data.Dataset):
    """
    A placeholder Dataset over in-memory images that resizes each image on
    access, so that a DataLoader with several workers can parallelize the
    resize operation using multiple CPU cores

    files: indexable collection of images; files[i] is passed to
        custom_image_tranform and then to fn_resize
    mode: resizing mode forwarded to fid.build_resizer
    size: stored on the instance; not used by this class itself
    fdir: stored on the instance; not used by this class itself

    Attributes set by the constructor:
    fn_resize: function that takes an np_array as input [0,255]
    custom_image_tranform: hook applied before resizing (identity by default)
    """

    def __init__(self, files, mode, size=(299, 299), fdir=None):
        self.files = files
        self.fdir = fdir
        self.transforms = torchvision.transforms.ToTensor()
        self.size = size
        self.fn_resize = fid.build_resizer(mode)
        # attribute name (including its spelling) is kept as-is so external
        # code that overrides it keeps working
        self.custom_image_tranform = lambda x: x

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        """Return image ``i`` after the custom transform, resize and ToTensor.

        Raises:
            TypeError: if the resized image is neither uint8 nor float32.
        """
        # apply a custom image transform before resizing the image
        img_np = self.custom_image_tranform(self.files[i])
        # fn_resize expects a np array and returns a np array
        img_resized = self.fn_resize(img_np)

        if img_resized.dtype == "uint8":
            # ToTensor() converts to [0,1] only if input in uint8, so the
            # result is scaled back by 255
            return self.transforms(np.array(img_resized)) * 255
        if img_resized.dtype == "float32":
            return self.transforms(img_resized)

        # Previously this case fell through and failed with an opaque
        # UnboundLocalError; fail with a message that names the cause.
        raise TypeError(
            f"unsupported dtype {img_resized.dtype} after resizing; "
            "expected uint8 or float32"
        )

# https://github.com/openai/consistency_models_cifar10/blob/main/jcm/metrics.py#L117
def compute_fid(
    samples,
    feat_model,
    dataset_name="cifar10",
    ref_stat=None,
    dataset_res=32,
    dataset_split="train",
    batch_size=512,
    num_workers=12,
    mode="legacy_tensorflow",
    device=torch.device("cuda:0"),
    seed=0,
):
    """Score ``samples`` against reference statistics.

    The samples are wrapped in a ResizeDataset, run batch by batch through
    ``fid.get_batch_features``, and summarized by the mean and covariance of
    the resulting features. Those are compared with the reference mean and
    covariance via ``fid.frechet_distance``.

    Args:
        samples: indexable collection of images handed to ResizeDataset.
        feat_model: feature extractor passed to ``fid.get_batch_features``.
        dataset_name, dataset_res, dataset_split, seed: forwarded to
            ``fid.get_reference_statistics``; only used when ``ref_stat``
            is None.
        ref_stat: optional ``(mu, sigma)`` pair used as the reference
            instead of looking statistics up by dataset name.
        batch_size, num_workers: DataLoader settings.
        mode: resizer mode, also forwarded to the reference lookup.
        device: device passed to ``fid.get_batch_features``.

    Returns:
        The value returned by ``fid.frechet_distance``.
    """
    loader = torch.utils.data.DataLoader(
        ResizeDataset(samples, mode=mode),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

    features = np.concatenate(
        [
            fid.get_batch_features(batch, feat_model, device)
            for batch in tqdm.tqdm(loader)
        ]
    )
    sample_mu = np.mean(features, axis=0)
    sample_sigma = np.cov(features, rowvar=False)

    # Caller-supplied statistics take precedence over the named dataset.
    if ref_stat is None:
        target_mu, target_sigma = fid.get_reference_statistics(
            dataset_name, dataset_res, mode=mode, seed=seed, split=dataset_split
        )
    else:
        target_mu, target_sigma = ref_stat

    return fid.frechet_distance(sample_mu, sample_sigma, target_mu, target_sigma)

def _to_uint8_nhwc(array, image_size, num_channels):
    """Map a (B, C, H, W) array in roughly [-1, 1] to uint8 (B, H, W, C) images."""
    tensor = torch.from_numpy(array) / 2 + 0.5
    pixels = tensor.permute(0, 2, 3, 1).cpu().numpy() * 255.
    images = np.clip(pixels, 0, 255).astype(np.uint8)
    return images.reshape((-1, image_size, image_size, num_channels))


def _feature_statistics(images, feat_model, mode, device, batch_size=512):
    """Return the (mean, covariance) of the features extracted from ``images``."""
    loader = torch.utils.data.DataLoader(
        ResizeDataset(images, mode=mode),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )
    feats = np.concatenate(
        [
            fid.get_batch_features(batch, feat_model, device)
            for batch in tqdm.tqdm(loader)
        ]
    )
    return np.mean(feats, axis=0), np.cov(feats, rowvar=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default='cifar10')
    parser.add_argument("--data_path", type=str, default='')
    parser.add_argument("--target_path", type=str, default='')
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--num_channels", type=int, default=3)
    parser.add_argument("--num_samples", type=int, default=50000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Validate with the parser rather than `assert`, which is stripped
    # when Python runs with -O.
    if not args.data_path:
        parser.error("data_path must be specified.")

    seed_everywhere(args.seed)

    ### basic info
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}; version: {str(torch.version.cuda)}')

    ### build feature extractor
    mode = "legacy_tensorflow"
    feat_model = fid.build_feature_extractor(mode, device)

    # Shift the seed by an offset drawn from NumPy's generator. That
    # generator was seeded with args.seed just above, so the offset is
    # reproducible for a given --seed.
    args.seed += np.random.randint(0, 1000000)
    print(f'Using seed: {args.seed};')
    ### set random seed everywhere
    seed_everywhere(args.seed)

    ### load target samples and calculate reference statistics
    ref_stat = None
    if args.target_path:
        print(f'load target samples from {args.target_path}')
        loaded = np.load(args.target_path)
        if isinstance(loaded, np.ndarray):
            # A plain array holds raw target samples: extract their
            # features and compute the statistics here.
            target_samples = _to_uint8_nhwc(
                loaded, args.image_size, args.num_channels
            )
            ref_stat = _feature_statistics(target_samples, feat_model, mode, device)
            print('reference statistics calcualted!')
        else:
            # Otherwise expect an archive with pre-computed 'mu' and 'sigma',
            # e.g. from https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/
            with loaded:
                ref_stat = (loaded['mu'], loaded['sigma'])
            print('reference statistics loaded!')

    ### calculate fid for given data
    print(f'calculate fid for data from {args.data_path}')
    samples = np.load(args.data_path)
    print(f'samples shape: {samples.shape}')
    print(f'samples range: {samples.min()}, {samples.max()}, should be ~ [-1, 1].')
    all_samples = _to_uint8_nhwc(samples, args.image_size, args.num_channels)
    all_samples = all_samples[: args.num_samples]
    fid_score = compute_fid(
        all_samples,
        mode=mode,
        # was the undefined name `dataset_name`, which raised NameError
        dataset_name=args.dataset_name,
        # keep the reference lookup in step with --image_size (default 32)
        dataset_res=args.image_size,
        device=device,
        feat_model=feat_model,
        seed=args.seed,
        num_workers=0,
        ref_stat=ref_stat,
    )
    print(f"data_path-{args.data_path} --- FID: {fid_score:0.6f}")