FwiffoSnork opened 3 years ago
The VGG feature detector in projector.py is trained on 3-channel images:
# Load VGG16 feature detector.
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
Is there another version, or are there directions on how to train it for grayscale images?
Have you tried converting the image to 3 channels before it is fed to the feature detector? For example, for an Nx1xHxW image tensor img, use torch.cat([img, img, img], dim=1).
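A minimal sketch of that suggestion, assuming the images are PyTorch tensors in NCHW layout. The helper name gray_to_rgb is hypothetical and not part of stylegan2-ada-pytorch. It copies the single grayscale channel into three identical R, G, B channels, which gives the 3-channel shape the VGG16 detector expects, and it leaves 3-channel input untouched.

```python
import torch


def gray_to_rgb(img: torch.Tensor) -> torch.Tensor:
    """Replicate a single-channel NCHW batch to three identical channels.

    Hypothetical helper (not part of projector.py): the VGG16 feature
    detector expects 3-channel input, so a grayscale batch is fed to it
    as three copies of the same channel.
    """
    if img.ndim != 4:
        raise ValueError(f"expected an NCHW tensor, got shape {tuple(img.shape)}")
    if img.shape[1] == 3:
        return img  # already 3-channel, nothing to do
    if img.shape[1] != 1:
        raise ValueError(f"expected 1 or 3 channels, got {img.shape[1]}")
    # Concatenate the single channel with itself along the channel axis.
    return torch.cat([img, img, img], dim=1)


# Example: a batch of two 1x64x64 grayscale images with values in [0, 255].
gray = torch.rand(2, 1, 64, 64) * 255
rgb = gray_to_rgb(gray)
print(rgb.shape)  # torch.Size([2, 3, 64, 64])
```

img.repeat(1, 3, 1, 1) produces the same result, and img.expand(-1, 3, -1, -1) does so without copying memory. In projector.py, a conversion like this would presumably have to be applied to both the target and the synthesized images before each call to the feature detector, so that both sides of the comparison have the same channel count.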
I have also encountered the same problem. Could you please tell me your final solution?
projector.py converts your target image to RGB, but this of course causes an assertion error with a network trained on greyscale images. I am working on resolving this on my own. Here is the snippet of code that seems to be the first of the greyscale problems: