GaParmar / clean-fid

PyTorch - FID calculation with proper image resizing and quantization steps [CVPR 2022]
https://www.cs.cmu.edu/~clean-fid/
MIT License

'z_batch' size BUG for Generator with 4D/3D tensors as inputs #55

Open EvangelosTikas opened 7 months ago

EvangelosTikas commented 7 months ago

I was trying to load a generator from a ProGAN implementation in PyTorch. This generator takes a tensor of shape (batch_size, Z_DIM, 1, 1) as input, where Z_DIM is the latent dimension (e.g. 512). The code in cleanfid/fid.py (get_model_features, around line 214) samples the latent batch as z_batch = torch.randn((batch_size, z_dim)).to(device), which is a 2D tensor. That works only for generators whose first layer takes a flat latent vector (as in a vanilla GAN or similar examples); it fails when the first layer is a convolution. I think the code should also support 4D latents of shape (batch_size, z_dim, 1, 1), since many convolutional GAN implementations expect that as generator input.
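To make the shape mismatch concrete, here is a minimal sketch. The single `ConvTranspose2d` layer is only an assumed stand-in for the first layer of a ProGAN-style generator, not the actual model from this report, and the channel counts are arbitrary.

```python
import torch
import torch.nn as nn

batch_size, z_dim = 4, 512

# Assumed stand-in for the first layer of a ProGAN-style generator.
first_layer = nn.ConvTranspose2d(z_dim, 8, kernel_size=4)

z_2d = torch.randn(batch_size, z_dim)       # shape sampled by get_model_features
z_4d = z_2d.view(batch_size, z_dim, 1, 1)   # shape the generator expects

# The 2D latent is rejected by the transposed convolution.
try:
    first_layer(z_2d)
    failed = False
except RuntimeError:
    failed = True

# The same values reshaped to 4D pass through without error.
with torch.no_grad():
    out = first_layer(z_4d)

print(failed, tuple(out.shape))
```

The reshape adds no information; it only appends two singleton spatial dimensions so the convolution sees a 1x1 feature map with z_dim channels.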

[Error]:

compute FID of a model with <name-of-precomputed-statistics-file>
FID model:   0%|                                                                       | 0/1563 [00:00<?, ?it/s]
Traceback (most recent call last):
....
....
  File "/home/user/Documents/GANS/metrics/cleanfid/fid.py", line 251, in fid_model
    np_feats = get_model_features(G, model, mode=mode,
  File "/home/user/Documents/GANS/metrics/cleanfid/fid.py", line 214, in get_model_features
    img_batch = G(z_batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 956, in forward
    return F.conv_transpose2d(
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv_transpose2d, but got input of size: [32, 512]
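Until the library supports this directly, one possible workaround is to wrap the generator so that it accepts the 2D latent and reshapes it internally. This is a sketch, not part of clean-fid: `LatentReshape` is a hypothetical adapter name, and `toy_G` is a toy generator used only to check shapes.

```python
import torch
import torch.nn as nn


class LatentReshape(nn.Module):
    """Hypothetical adapter: maps (B, z_dim) latents to (B, z_dim, 1, 1)."""

    def __init__(self, generator):
        super().__init__()
        self.generator = generator

    def forward(self, z):
        # Append two singleton spatial dimensions before calling the generator.
        return self.generator(z.view(z.size(0), -1, 1, 1))


# Toy stand-in generator whose first layer is a transposed convolution.
toy_G = nn.Sequential(nn.ConvTranspose2d(512, 3, kernel_size=4), nn.Tanh())

G = LatentReshape(toy_G)
with torch.no_grad():
    imgs = G(torch.randn(2, 512))  # 2D latent, as sampled in get_model_features

print(tuple(imgs.shape))
```

The wrapped module can then be passed wherever the original generator was passed. Note that, per the comment in get_model_features, the generated images are expected to be in the range [0,255], so a generator ending in Tanh would also need its output rescaled inside the wrapper.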

code snippet [fid.py]:

"""
Compute the FID stats from a generator model
"""
def get_model_features(G, model, mode="clean", z_dim=512,
        num_gen=50_000, batch_size=128, device=torch.device("cuda"),
        desc="FID model: ", verbose=True, return_z=False,
        custom_image_tranform=None, custom_fn_resize=None):
    if custom_fn_resize is None:
        fn_resize = build_resizer(mode)
    else:
        fn_resize = custom_fn_resize

    # Generate test features
    num_iters = int(np.ceil(num_gen / batch_size))
    l_feats = []
    latents = []
    if verbose:
        pbar = tqdm(range(num_iters), desc=desc)
    else:
        pbar = range(num_iters)
    for idx in pbar:
        with torch.no_grad():
            z_batch = torch.randn((batch_size, z_dim)).to(device)
            if return_z:
                latents.append(z_batch)
            # generated image is in range [0,255]
            img_batch = G(z_batch)
            # split into individual batches for resizing if needed
            if mode != "legacy_tensorflow":
                l_resized_batch = []
                for idx in range(batch_size):