NVlabs / stylegan2-ada-pytorch

StyleGAN2-ADA - Official PyTorch implementation
https://arxiv.org/abs/2006.06676

Greyscale Projection #51

Open FwiffoSnork opened 3 years ago

FwiffoSnork commented 3 years ago

projector.py converts the target image to RGB, which causes an assertion error with a network trained on greyscale images. I am working on resolving this on my own. Here is the snippet of code that seems to be the first of the greyscale problems:

    # Load target image.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callabl
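
One possible way around the forced RGB conversion, as a minimal sketch: choose the PIL mode from the network's channel count and keep a trailing channel axis so the later transpose([2, 0, 1]) still works. The helper name load_target and its parameters are hypothetical and not part of projector.py; the sketch assumes the caller passes G.img_resolution and the generator's channel count (assumed to be available as G.img_channels).

```python
import numpy as np
import PIL.Image


def load_target(target_pil, resolution, channels):
    """Hypothetical helper: center-crop, resize, and return a uint8 array
    of shape (resolution, resolution, channels)."""
    # Assumption: 1 channel means greyscale ('L'); anything else is treated as RGB.
    mode = 'L' if channels == 1 else 'RGB'
    target_pil = target_pil.convert(mode)

    # Same center crop and resize as the original snippet.
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((resolution, resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # A mode 'L' image converts to a 2-D array; add the channel axis back
    # so that transpose([2, 0, 1]) yields a CxHxW layout.
    if target_uint8.ndim == 2:
        target_uint8 = target_uint8[:, :, np.newaxis]
    return target_uint8
```

With channels=1 the result has shape (H, W, 1), so the transposed tensor is 1xHxW. This only covers image loading; the VGG16 feature detector used later still expects 3 channels.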

FwiffoSnork commented 3 years ago

The VGG feature detector in projector.py is trained for 3-channel images:

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'

Is there another version, or directions on how to train this for grayscale?

crowsonkb commented 3 years ago

> The VGG feature detector in projector.py is trained for 3-channel images.
>
>     # Load VGG16 feature detector.
>     url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
>
> Is there another version, or directions on how to train this for grayscale?

Have you tried converting to 3 channels before the image is fed to the feature detector? For example, for an Nx1xHxW image img: torch.cat([img, img, img], dim=1).
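
A minimal sketch of that conversion, assuming PyTorch tensors in NxCxHxW layout. The function name to_three_channels is hypothetical; it would presumably need to be applied to every batch handed to the feature detector, not just the target.

```python
import torch


def to_three_channels(img):
    """Hypothetical helper: replicate a single-channel Nx1xHxW batch
    along the channel axis to get Nx3xHxW; other inputs pass through."""
    if img.shape[1] == 1:
        # torch.cat is differentiable, so gradients from the 3-channel
        # copy flow back into the original single channel.
        img = torch.cat([img, img, img], dim=1)
    return img
```

Because the three channels are identical copies, the feature detector sees a grey RGB image, and no retraining of VGG16 is involved.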

koalahhh commented 1 year ago

I have also encountered the same problem. Could you please tell me your final solution?