NVlabs / stylegan2-ada-pytorch

StyleGAN2-ADA - Official PyTorch implementation
https://arxiv.org/abs/2006.06676

Converting StyleGAN2-ADA (or TensorFlow) models into PyTorch or StyleGAN2 #85

Open Randy-H0 opened 3 years ago

Randy-H0 commented 3 years ago

Hello, I currently have a problem. I noticed that the stylegan2-pytorch repo has a convert_weight.py script that lets you convert a (TensorFlow) StyleGAN2 model into PyTorch. That's really handy if you want to use GANSpace or something similar.

This repo, however, has no such option for converting your model weights into that PyTorch format. I tried brute-forcing it with some other converters, but that really did not turn out well.

Is there another converter that can turn these model weights into the PyTorch or StyleGAN2 format? I'd love to use GANSpace with it.

If there is no way to do this, would it be possible to instead hook GANSpace up to stylegan2-ada directly, rather than having GANSpace use its own PyTorch model?

Thanks so much!

katyonats commented 3 years ago

Hi, I've got exactly the same issue. Impatiently waiting for a response too :)

MoemaMike commented 3 years ago

Just trying to understand the problem as an observer of this issue: I trained several models using stylegan2-ada, then recently used stylegan2-ada-pytorch to generate and train with those stylegan2-ada models (networks) just fine, without any adjustment. Are you saying the compatibility is not bidirectional? That the PyTorch codebase can handle the TensorFlow models but not vice versa?

Randy-H0 commented 3 years ago

> Just trying to understand the problem as an observer of this issue: I trained several models using stylegan2-ada, then recently used stylegan2-ada-pytorch to generate and train with those stylegan2-ada models (networks) just fine, without any adjustment. Are you saying the compatibility is not bidirectional? That the PyTorch codebase can handle the TensorFlow models but not vice versa?

I cannot convert a stylegan2-ada model into that PyTorch format. I don't see any convert_weight Python file here that I could use to convert it into a .pt model.

MoemaMike commented 3 years ago

That is the part I do not understand. I did not do anything to "convert". I just invoked the PyTorch train (resume) and generate scripts, specified my TensorFlow-generated model in the --network param, and they just ran fine.

https://github.com/NVlabs/stylegan2-ada-pytorch/issues/74
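
For what it's worth, I believe the reason this works is that this repo's legacy loader detects old TensorFlow pickles and converts them to the PyTorch classes on the fly. A minimal sketch of loading either kind of snapshot (the pickle path is hypothetical):

```python
import torch
import dnnlib
import legacy  # this repo's loader; handles both TF and PyTorch pickles

device = torch.device('cuda')

# load_network_pkl() converts a TensorFlow snapshot on the fly if needed.
with dnnlib.util.open_url('network-snapshot.pkl') as f:  # hypothetical path
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

z = torch.randn([1, G.z_dim], device=device)
label = torch.zeros([1, G.c_dim], device=device)  # empty for unconditional
img = G(z, label, truncation_psi=0.7, noise_mode='const')
```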

Lamply commented 3 years ago

> That is the part I do not understand. I did not do anything to "convert". I just invoked the PyTorch train (resume) and generate scripts, specified my TensorFlow-generated model in the --network param, and they just ran fine.
>
> https://github.com/NVlabs/stylegan2-ada-pytorch/issues/74

I guess the PyTorch model used in stylegan2-ada-pytorch is different from https://github.com/rosinality/stylegan2-pytorch. The latter implementation is widely used in other PyTorch projects based on StyleGAN2.
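
A quick way to see the difference is to compare parameter names from the two formats; they don't overlap at all. A sketch, with hypothetical file paths:

```python
import torch
import dnnlib
import legacy  # stylegan2-ada-pytorch's loader

# Official ADA generator: keys like 'mapping.fc0.weight',
# 'synthesis.b4.conv1.weight'.
with dnnlib.util.open_url('ada-snapshot.pkl') as f:  # hypothetical path
    ada_keys = list(legacy.load_network_pkl(f)['G_ema'].state_dict())

# rosinality checkpoint: keys like 'style.1.weight', 'convs.0.conv.weight'.
ros_keys = list(torch.load('rosinality.pt', map_location='cpu')['g_ema'])

print(ada_keys[:5])
print(ros_keys[:5])
```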

dvschultz commented 3 years ago

Justin Pinkney wrote a way to export weights from the official PyTorch .pkl to the latest rosinality structure. I have a copy here: https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/export_weights.py
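
For reference, the core of what such an export does looks roughly like this. This is a sketch, not the actual script: the key remapping is the nontrivial part that export_weights.py implements, and remap_to_rosinality below is a hypothetical stand-in for it.

```python
import torch
import dnnlib
import legacy  # stylegan2-ada-pytorch's loader

# Load the official EMA generator from its pickle.
with dnnlib.util.open_url('network-snapshot.pkl') as f:  # hypothetical path
    G_nv = legacy.load_network_pkl(f)['G_ema']

# Every official parameter (e.g. 'synthesis.b4.conv1.weight') has to be
# renamed, and some reshaped, to its rosinality counterpart
# (e.g. 'convs.0.conv.weight'). That mapping is the bulk of the script.
state_ros = remap_to_rosinality(G_nv.state_dict())  # hypothetical helper

# rosinality checkpoints keep the EMA generator under the 'g_ema' key.
torch.save({'g_ema': state_ros}, 'converted.pt')
```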

Note that this will not work for GANSpace, because GANSpace uses an older version of the rosinality model. I recommend checking out Closed-Form Factorization, which does a similar thing through a different process (it's faster, too). I have a port of it in my repo here (it doesn't require conversion): https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/closed_form_factorization.py
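
The idea behind Closed-Form Factorization is simple enough to sketch: the edit directions are the right singular vectors of the stacked style-modulation weight matrices, so no sampling or training is needed. Roughly (after rosinality's version, as far as I recall; the checkpoint path is hypothetical):

```python
import torch

ckpt = torch.load('converted.pt', map_location='cpu')  # hypothetical path

# Collect the style-modulation weight matrices of all synthesis convs
# (the to_rgbs layers are excluded, as in rosinality's script).
modulate = {
    k: v for k, v in ckpt['g_ema'].items()
    if 'modulation' in k and 'to_rgbs' not in k and 'weight' in k
}

# Stack them and take the right singular vectors: each column of V is one
# latent edit direction.
W = torch.cat(list(modulate.values()), dim=0)
eigvec = torch.svd(W).V

torch.save({'ckpt': 'converted.pt', 'eigvec': eigvec}, 'factor.pt')
```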

skymanaditya1 commented 2 years ago

Hi @dvschultz, the link you shared above does convert a stylegan2-ada-pytorch model to a stylegan2-pytorch model. However, when trying to load the converted model with stylegan2-pytorch, I get key-mismatch errors:

```
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
	size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
	size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
	size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
	size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
	size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
	size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
	size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
	size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
	size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
	size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
```

yuchen1984 commented 1 year ago

This thread may be helpful: https://github.com/dvschultz/stylegan2-ada-pytorch/issues/6

It seems you may need to get the "n_mlp" and "latent" params right when constructing the rosinality Generator.
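
Concretely, the shapes in the error above suggest an ADA "auto"-config model: the missing style.3 through style.8 keys point to a 2-layer mapping network, and the halved channel widths point to half-width feature maps. A sketch of constructing rosinality's Generator to match (values inferred from the error, not verified against the actual checkpoint):

```python
import torch
from model import Generator  # rosinality/stylegan2-pytorch

g_ema = Generator(
    size=256,              # checkpoint's training resolution (inferred from
                           # convs going up to index 11 / to_rgbs up to 5)
    style_dim=512,         # 'latent': dimensionality of w
    n_mlp=2,               # 2 mapping layers -> only style.1/style.2 exist
    channel_multiplier=1,  # half-width fmaps, matching the halved shapes
)
ckpt = torch.load('converted.pt', map_location='cpu')  # hypothetical path
g_ema.load_state_dict(ckpt['g_ema'])
```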