dorarad / gansformer

Generative Adversarial Transformers
MIT License

Can I use FFHQ 1024 pre-trained model with PyTorch? #36

Open dk-hong opened 2 years ago

dk-hong commented 2 years ago

I executed PyTorch_version/loader.py with ffhq-snapshot.pkl and it worked well.

However, it didn't work with ffhq-snapshot-1024.pkl.

How can I resolve this issue?

Error messages are below.

Loading ffhq-snapshot-1024.pkl...
synthesis.b1024.conv_last.weight [32, 32, 3, 3]
Traceback (most recent call last):
  File "loader.py", line 324, in <module>
    convert_network_pickle()
  File "/home/ubuntu/anaconda3/envs/pytorch1.7.1_p37/lib/python3.7/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/pytorch1.7.1_p37/lib/python3.7/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/home/ubuntu/anaconda3/envs/pytorch1.7.1_p37/lib/python3.7/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ubuntu/anaconda3/envs/pytorch1.7.1_p37/lib/python3.7/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "loader.py", line 317, in convert_network_pickle
    data = load_network_pkl(f)
  File "loader.py", line 39, in load_network_pkl
    G = convert_tf_generator(tf_G)
  File "loader.py", line 233, in convert_tf_generator
    r".*\.grid_pos",                                    None,
  File "loader.py", line 83, in _populate_module_params
    assert found
AssertionError

Thank you in advance.

nebojsa-bozanic commented 2 years ago

Just add these after line 231 of loader.py (the conversion table was written for the 256×256 model, so it lacks mappings for the higher-resolution blocks):

```python
r"synthesis.b512.conv_last.weight",         lambda: tf_params[f"synthesis/512x512/ToRGB/extraLayer/weight"].transpose(3, 2, 0, 1),
r"synthesis.b512.conv_last.affine.weight",  lambda: tf_params[f"synthesis/512x512/ToRGB/extraLayer/mod_weight"].transpose(),
r"synthesis.b512.conv_last.affine.bias",    lambda: tf_params[f"synthesis/512x512/ToRGB/extraLayer/mod_bias"] + 1,
r"synthesis.b1024.conv_last.weight",        lambda: tf_params[f"synthesis/1024x1024/ToRGB/extraLayer/weight"].transpose(3, 2, 0, 1),
r"synthesis.b1024.conv_last.affine.weight", lambda: tf_params[f"synthesis/1024x1024/ToRGB/extraLayer/mod_weight"].transpose(),
r"synthesis.b1024.conv_last.affine.bias",   lambda: tf_params[f"synthesis/1024x1024/ToRGB/extraLayer/mod_bias"] + 1,
```
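For context on why the `AssertionError` fires: the converter matches each PyTorch parameter name against a list of (regex, value-producing lambda) pairs, and `assert found` fails for any name with no matching pattern, such as `synthesis.b1024.conv_last.weight` in the 1024 model. Below is a minimal sketch of that mechanism (the parameter store, shapes, and `populate` helper are illustrative, not the actual gansformer code):

```python
import re
import numpy as np

# Illustrative TF-style parameter store. TF stores conv weights as
# HWIO (height, width, in, out); the shape here is made up for the sketch.
tf_params = {
    "synthesis/1024x1024/ToRGB/extraLayer/weight": np.zeros((3, 3, 32, 32)),
}

# (pattern, value_fn) pairs, mirroring the loader's mapping table.
# Dots are left unescaped, as in the original patterns.
patterns = [
    (r"synthesis.b1024.conv_last.weight",
     lambda: tf_params["synthesis/1024x1024/ToRGB/extraLayer/weight"].transpose(3, 2, 0, 1)),
]

def populate(name):
    """Return the converted tensor for a PyTorch parameter name, or raise."""
    for pat, fn in patterns:
        if re.fullmatch(pat, name):
            return fn()
    # This is the point where loader.py hits `assert found` for
    # parameters that lack a mapping (e.g. the b512/b1024 blocks).
    raise AssertionError(f"no mapping for {name}")

w = populate("synthesis.b1024.conv_last.weight")
print(w.shape)  # TF HWIO -> PyTorch OIHW: (32, 32, 3, 3)
```

The `.transpose(3, 2, 0, 1)` permutes TF's HWIO layout into PyTorch's OIHW, which is why the log prints `synthesis.b1024.conv_last.weight [32, 32, 3, 3]` just before the unmapped-parameter assertion.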

Cheers