NVlabs / stylegan2-ada-pytorch

StyleGAN2-ADA - Official PyTorch implementation
https://arxiv.org/abs/2006.06676

How do I change the conv1 layer of the pretrained model vgg16.pt? #202

Open 89douner opened 2 years ago

89douner commented 2 years ago

In 'projector.py', there is a pretrained model: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt

The conv1 layer of this model takes 3 input channels, but I want to convert it to take 1 input channel.

Usually, the first layer of other pretrained PyTorch models can be converted to 1 input channel with the code below.

self.backbone = models.resnet50(pretrained=pretrained)
self.backbone.conv1 = nn.Conv2d(img_channel, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

checkpoint = torch.load("/home/ubuntu/.cache/torch/checkpoints/resnet50-19c8e357.pth")
conv1_weight = checkpoint['conv1.weight']
self.backbone.conv1.weight = torch.nn.Parameter(conv1_weight.mean(dim=1, keepdim=True))
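
The weight-averaging idea in the snippet above can be checked in isolation. The following is a minimal sketch, not the original poster's code: it uses a randomly initialised `nn.Conv2d` as a stand-in for a pretrained first layer (an assumption, so nothing has to be downloaded) and the names `conv_rgb` and `conv_gray` are made up for illustration. Without a bias term, averaging the kernels over the input-channel axis gives one third of the response the 3-channel layer produces when the same grayscale image is fed to all three channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pretrained 3-channel first layer (same geometry as the conv1 above).
conv_rgb = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# New 1-channel layer with identical kernel size, stride and padding.
conv_gray = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Average the kernels over the input-channel axis: (64, 3, 7, 7) -> (64, 1, 7, 7).
with torch.no_grad():
    conv_gray.weight.copy_(conv_rgb.weight.mean(dim=1, keepdim=True))

gray = torch.randn(2, 1, 32, 32)
out_gray = conv_gray(gray)

# Feed the same grayscale image to all three channels of the RGB layer.
out_rgb = conv_rgb(gray.expand(-1, 3, -1, -1))

# The two outputs differ only by the factor 3 introduced by the mean.
max_diff = (out_rgb / 3 - out_gray).abs().max().item()
print(out_gray.shape, max_diff)
```

If the original activation scale should be preserved for grayscale inputs, `sum(dim=1, keepdim=True)` can be used instead of `mean`.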

However, this approach does not work with the NVIDIA pretrained model. I think the failure is related to the model being a ScriptModule.

Please, give me some tips!

koalahhh commented 1 year ago

I have also encountered the same problem. Could you please tell me your final solution?

koalahhh commented 1 year ago

It seems the original code for this VGG16 network was written in Keras. I tried to solve the problem like this:

import torch
import torch.nn as nn
import dnnlib  # from the stylegan2-ada-pytorch repository

url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
with dnnlib.util.open_url(url) as f:
    vgg16_model = torch.jit.load(f)  # the checkpoint is a TorchScript module
vgg16_pretrained_dict = vgg16_model.state_dict()
print(vgg16_pretrained_dict.items())
vgg16_layer1 = vgg16_pretrained_dict['layers.conv1.weight']
print(vgg16_layer1)
# Collapse the three input channels of every output filter into one.
new = torch.zeros(64, 1, 3, 3)
for i, output_channel in enumerate(vgg16_layer1):
    # Gray = 0.299R + 0.587G + 0.114B
    new[i] = 0.299 * output_channel[0] + 0.587 * output_channel[1] + 0.114 * output_channel[2]
# Note: this key differs from 'layers.conv1.weight' read above, so it only adds a new dict entry.
vgg16_pretrained_dict['features.0.0.weight'] = new
Conv2dLayer = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
conv1_new_script = torch.jit.script(Conv2dLayer)
vgg16_model.layers.conv1 = conv1_new_script  # this assignment raises the RuntimeError
vgg16 = vgg16_model

But it reports an error:

RuntimeError: Expected a value of type 'torch.Conv2dLayer' for field 'conv1', but found 'torch.torch.nn.modules.conv.Conv2d'
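
The error message indicates that the `conv1` field of the loaded ScriptModule is typed as `Conv2dLayer`, so a scripted `nn.Conv2d` cannot be assigned to it. One possible workaround, sketched here under assumptions and not confirmed by the repository authors, leaves the ScriptModule untouched and replicates the grayscale channel three times before calling it. `GrayToRGBWrapper` is a hypothetical helper name, and a scripted `nn.Conv2d` stands in for the real `vgg16.pt` so the snippet runs without a download.

```python
import torch
import torch.nn as nn


class GrayToRGBWrapper(nn.Module):
    """Hypothetical helper: feeds a 1-channel batch to a model that expects 3 channels."""

    def __init__(self, rgb_model: nn.Module):
        super().__init__()
        self.rgb_model = rgb_model

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # (N, 1, H, W) -> (N, 3, H, W); expand() returns a view, so no data is copied.
        # Extra keyword arguments are forwarded unchanged to the wrapped model.
        return self.rgb_model(x.expand(-1, 3, -1, -1), **kwargs)


torch.manual_seed(0)

# Stand-in for the loaded ScriptModule (assumption: the real model would come from torch.jit.load).
rgb_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
stand_in = torch.jit.script(rgb_conv)

model = GrayToRGBWrapper(stand_in)
gray = torch.randn(2, 1, 16, 16)
features = model(gray)
print(features.shape)
```

Because convolution is linear in its input, replicating the gray channel is equivalent to summing the first-layer weights over the input-channel axis. It is not the same as the 0.299/0.587/0.114 luminance weighting used in the attempt above.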