eladrich / pixel2style2pixel

Official Implementation for "Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation" (CVPR 2021) presenting the pixel2style2pixel (pSp) framework
https://eladrich.github.io/pixel2style2pixel/
MIT License
3.2k stars 568 forks source link

Using my own pretrained model from vanilla StyleGAN2 #309

Closed demiahmed closed 1 year ago

demiahmed commented 1 year ago

Can I use my own pretrained model .pt which I trained using vanilla StyleGAN2 for encoding my own images?

Or Should I be training the encoder again from this repo?

yuval-alaluf commented 1 year ago

When you say pre-trained model, are you referring to a StyleGAN2 generator or an encoder? If you have a different StyleGAN2 generator than what is used in this repo, then you will need to train your own new encoder since the latent spaces will be different.

demiahmed commented 1 year ago

Sorry I wasn't clear earlier. I was referring to a pre-trained StyleGAN2 Generator model which I trained using the StyleGAN2-ADA official repo. When I try to run the pSp training, I get a tensor size mismatch and unable to figure out why even though the configs like resolution and z-vector matches.

 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
        size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
        size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
        size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
        size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
        size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
        size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
        size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
        size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
        size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
        size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
        size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
        size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).

Any help would be highly appreciated

yuval-alaluf commented 1 year ago

It seems like you need to change the channel multiplier for StyleGAN. Please take a look at the following line: https://github.com/eladrich/pixel2style2pixel/blob/5cfff385beb7b95fbce775662b48fcc80081928d/models/psp.py#L31 You should add, channel_multiplier=1 when defining the Generator class. Let me know if this solves the issue.