yuval-alaluf / stylegan3-editing

Official Implementation of "Third Time's the Charm? Image and Video Editing with StyleGAN3" (AIM ECCVW 2022) https://arxiv.org/abs/2201.13433
https://yuval-alaluf.github.io/stylegan3-editing/
MIT License
660 stars 72 forks source link

[Question] How to convert pkl to pt file? #18

Closed HuaZheLei closed 2 years ago

HuaZheLei commented 2 years ago

Thanks for your excellent work!

Describe the problem

I'd like to learn how you converted the .pkl files to .pt files. I use the .pt file you provide to generate images. The code is as follows:

def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Sample one image from a StyleGAN3 generator using a seeded random latent.

    Args:
        generator: Generator module exposing ``mapping`` and ``synthesis``.
        truncation_psi: Truncation value forwarded to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed gives the same z.

    Returns:
        ``(res_image, w)``: the ``tensor2im`` conversion of the first image in
        the batch, and the latent returned by ``generator.mapping``.
    """
    with torch.no_grad():
        # Create z on the generator's own device instead of hard-coding 'cuda', so a
        # CPU-loaded generator works too; fall back to 'cuda' if it has no parameters.
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')
        # Latent size: prefer generator.z_dim when the generator defines it
        # (assumption for StyleGAN3 generators — confirm), else the former constant 512.
        z_dim = getattr(generator, 'z_dim', 512)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim).astype('float32')).to(device)
        if hasattr(generator.synthesis, 'input'):
            # Zero translation / zero rotation; the matrix is inverted before being copied
            # into the input layer's transform buffer (mirrors StyleGAN3's gen_images.py — confirm).
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w

It works well. But when I convert the .pkl to a .pt file myself, several errors appear. The conversion code I used is as follows:

import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
# Write the .pt next to the .pkl, with the same stem.
output_path = Path(checkpoint_path).with_suffix(".pt")
print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only
# unpickle checkpoints from a trusted source. Unpickling presumably also requires the
# StyleGAN3 code (e.g. its torch_utils / dnnlib packages) to be importable — confirm.
with open(checkpoint_path, "rb") as f:
    # No .cuda(): converting weights needs no GPU, and the saved tensors then keep the
    # device they were unpickled on instead of being tagged as CUDA tensors.
    decoder = pickle.load(f)['G_ema']
print('Loading done!')

# Only the weights are saved. Whoever loads this file must construct a Generator with
# the same architecture config (e.g. StyleGAN3-T vs. StyleGAN3-R); otherwise
# load_state_dict reports missing/unexpected keys and size mismatches.
state_dict = decoder.state_dict()
torch.save(state_dict, output_path)
print('Converting done!')

Then I used stylegan3-t-ffhq-1024x1024.pt to generate images and got the following errors:

Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
  File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
    self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
    Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
    size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gen_images_using_pt.py", line 79, in <module>
    main()
  File "gen_images_using_pt.py", line 47, in main
    generator = SG3Generator(checkpoint_path=args.generator_path).decoder
  File "/sam/models/stylegan3/model.py", line 56, in __init__
    self._load_checkpoint(checkpoint_path)
  File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
    self.decoder.load_state_dict(ckpt, strict=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
    size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
yuval-alaluf commented 2 years ago

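I wonder whether the error occurs because you are using the StyleGAN3-T config, while the generators we defined are for the StyleGAN3-R config. The two configs use different channel counts (compare `synthesis.L0_36_512` in your checkpoint with the `synthesis.L0_36_1024` the model expects), so the layer names and tensor shapes cannot match. Does the same error happen when you load a .pt converted from the R config?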

HuaZheLei commented 2 years ago

Cool! I did not realize they were different. I will add a T-config generator to my code. Thanks again.

uselessai commented 2 years ago

Hi HuaZheLei, I am trying to generate images from a .pt model, but I am not sure how to load the model. How can I load the .pt model? Thanks!!

Thanks for your excellent work!

Describe the problem

I'd like to learn how you converted the .pkl files to .pt files. I use the .pt file you provide to generate images. The code is as follows:

def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Sample one image from a StyleGAN3 generator using a seeded random latent.

    Args:
        generator: Generator module exposing ``mapping`` and ``synthesis``.
        truncation_psi: Truncation value forwarded to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed gives the same z.

    Returns:
        ``(res_image, w)``: the ``tensor2im`` conversion of the first image in
        the batch, and the latent returned by ``generator.mapping``.
    """
    with torch.no_grad():
        # Create z on the generator's own device instead of hard-coding 'cuda', so a
        # CPU-loaded generator works too; fall back to 'cuda' if it has no parameters.
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')
        # Latent size: prefer generator.z_dim when the generator defines it
        # (assumption for StyleGAN3 generators — confirm), else the former constant 512.
        z_dim = getattr(generator, 'z_dim', 512)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim).astype('float32')).to(device)
        if hasattr(generator.synthesis, 'input'):
            # Zero translation / zero rotation; the matrix is inverted before being copied
            # into the input layer's transform buffer (mirrors StyleGAN3's gen_images.py — confirm).
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w

It works well. But when I convert the .pkl to a .pt file myself, several errors appear. The conversion code I used is as follows:

import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
# Write the .pt next to the .pkl, with the same stem.
output_path = Path(checkpoint_path).with_suffix(".pt")
print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only
# unpickle checkpoints from a trusted source. Unpickling presumably also requires the
# StyleGAN3 code (e.g. its torch_utils / dnnlib packages) to be importable — confirm.
with open(checkpoint_path, "rb") as f:
    # No .cuda(): converting weights needs no GPU, and the saved tensors then keep the
    # device they were unpickled on instead of being tagged as CUDA tensors.
    decoder = pickle.load(f)['G_ema']
print('Loading done!')

# Only the weights are saved. Whoever loads this file must construct a Generator with
# the same architecture config (e.g. StyleGAN3-T vs. StyleGAN3-R); otherwise
# load_state_dict reports missing/unexpected keys and size mismatches.
state_dict = decoder.state_dict()
torch.save(state_dict, output_path)
print('Converting done!')

Then I used stylegan3-t-ffhq-1024x1024.pt to generate images and got the following errors:

Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
  File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
    self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
  Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
  Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
  size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
  size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
  size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
  size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
  size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
  size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gen_images_using_pt.py", line 79, in <module>
    main()
  File "gen_images_using_pt.py", line 47, in main
    generator = SG3Generator(checkpoint_path=args.generator_path).decoder
  File "/sam/models/stylegan3/model.py", line 56, in __init__
    self._load_checkpoint(checkpoint_path)
  File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
    self.decoder.load_state_dict(ckpt, strict=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
  size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
  size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
  size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
  size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
  size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
  size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
HuaZheLei commented 2 years ago

Hi HuaZheLei, I am trying to generate images from a .pt model, but I am not sure how to load the model. How can I load the .pt model? Thanks!!

Thanks for your excellent work!

Describe the problem

I'd like to learn how you converted the .pkl files to .pt files. I use the .pt file you provide to generate images. The code is as follows:

def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Sample one image from a StyleGAN3 generator using a seeded random latent.

    Args:
        generator: Generator module exposing ``mapping`` and ``synthesis``.
        truncation_psi: Truncation value forwarded to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed gives the same z.

    Returns:
        ``(res_image, w)``: the ``tensor2im`` conversion of the first image in
        the batch, and the latent returned by ``generator.mapping``.
    """
    with torch.no_grad():
        # Create z on the generator's own device instead of hard-coding 'cuda', so a
        # CPU-loaded generator works too; fall back to 'cuda' if it has no parameters.
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')
        # Latent size: prefer generator.z_dim when the generator defines it
        # (assumption for StyleGAN3 generators — confirm), else the former constant 512.
        z_dim = getattr(generator, 'z_dim', 512)
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim).astype('float32')).to(device)
        if hasattr(generator.synthesis, 'input'):
            # Zero translation / zero rotation; the matrix is inverted before being copied
            # into the input layer's transform buffer (mirrors StyleGAN3's gen_images.py — confirm).
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w

It works well. But when I convert the .pkl to a .pt file myself, several errors appear. The conversion code I used is as follows:

import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

# Load the pickled StyleGAN3 training snapshot and keep only the 'G_ema'
# generator, moved to the GPU.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only convert checkpoints obtained from a trusted source.
checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
with open(checkpoint_path, "rb") as f:
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')

# Save only the generator's parameter/buffer tensors. A state dict does not
# record the architecture hyperparameters (channel counts, layer names), so the
# model that later loads this file must be constructed with a matching
# configuration. The size-mismatch errors reported below (e.g. 512 vs 1024
# channels in synthesis.input) appear to be caused by such a mismatch.
state_dict = decoder.state_dict()
torch.save(state_dict, "pretrained_models/stylegan3-t-ffhq-1024x1024.pt")
print('Converting done!')

Then I used stylegan3-t-ffhq-1024x1024.pt to generate images and got the following errors:

Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
  File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
    self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
    Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
    size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gen_images_using_pt.py", line 79, in <module>
    main()
  File "gen_images_using_pt.py", line 47, in main
    generator = SG3Generator(checkpoint_path=args.generator_path).decoder
  File "/sam/models/stylegan3/model.py", line 56, in __init__
    self._load_checkpoint(checkpoint_path)
  File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
    self.decoder.load_state_dict(ckpt, strict=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
    size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
    size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
    size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

Hi, I will share my code here.

import os
import argparse
from typing import Tuple, List, Union

import numpy as np
import torch

from models.stylegan3.model import SG3Generator
from models.stylegan3.networks_stylegan3 import Generator
from utils.common import tensor2im

def make_transform(translate: Tuple[float, float], angle: float):
    """Return a 3x3 homogeneous 2-D matrix combining rotation and translation.

    ``angle`` is given in degrees. Row 0 is ``[cos, sin, tx]``, row 1 is
    ``[-sin, cos, ty]`` and row 2 is left as ``[0, 0, 1]``.
    """
    theta = angle / 360.0 * np.pi * 2
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)
    matrix = np.eye(3)
    matrix[0] = [cos_t, sin_t, translate[0]]
    matrix[1] = [-sin_t, cos_t, translate[1]]
    return matrix

def main():
    """Generate ``args.image_numbers`` images and save them as PNG files.

    Seeds run from 0 to ``image_numbers - 1``; each image is written to
    ``args.save_dir`` as ``seedNNNN.png`` (seed zero-padded to 4 digits).
    """
    args = parse_args()
    save_dir = args.save_dir
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    generator = SG3Generator(checkpoint_path=args.generator_path).decoder

    for i in range(args.image_numbers):
        print('Generating image for seed %d (%d/%d) ...' % (i, i, args.image_numbers))
        # The latent returned alongside the image is not needed here.
        image, _ = get_random_image(generator, truncation_psi=args.truncation_psi, seed=i)
        image.save(os.path.join(save_dir, 'seed' + str(i).zfill(4) + '.png'))

def get_random_image(generator: Generator, truncation_psi: float, seed,
                     translate: Tuple[float, float] = (0, 0), rotate: float = 0,
                     device='cuda'):
    """Generate one image from a seeded random latent.

    Args:
        generator: generator exposing ``mapping`` and ``synthesis``.
        truncation_psi: truncation value forwarded to ``generator.mapping``.
        seed: seed for ``np.random.RandomState`` used to draw ``z``.
        translate: (x, y) translation passed to ``make_transform``; the
            default ``(0, 0)`` reproduces the previous fixed behaviour.
        rotate: rotation in degrees passed to ``make_transform``; default 0.
        device: device the latent ``z`` is moved to (previously hard-coded
            to ``'cuda'``).

    Returns:
        ``(image, w)``: ``tensor2im`` of the first synthesized sample and the
        latent produced by ``generator.mapping``.
    """
    with torch.no_grad():
        # Use the generator's latent width when it advertises one; 512 was the
        # previously hard-coded value. NOTE(review): assumes `z_dim` is the
        # mapping network's input size, as in the NVIDIA StyleGAN3 Generator.
        z_dim = getattr(generator, 'z_dim', 512)
        z = torch.from_numpy(
            np.random.RandomState(seed).randn(1, z_dim).astype('float32')
        ).to(device)
        if hasattr(generator.synthesis, 'input'):
            # The inverse of the requested transform is written into the
            # synthesis input's `transform` buffer.
            m = np.linalg.inv(make_transform(translate=translate, angle=rotate))
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w

Hope it helps.

uselessai commented 2 years ago

Thank you very much! You saved my day!