Closed romesa-khan closed 2 years ago
Hi,
Thank you for your interest in our work!
How big is your custom dataset? This project is mostly focused on training without data, but if you tell me roughly how much data you have (10? 100? 5000 images?), maybe I could point you at projects that are a good fit?
Thanks for the prompt reply! I have ~1200 images.
With 1200 images you could try training a StyleGAN-ADA model. Just make sure you train with either the TensorFlow version of ADA or the Rosinality PyTorch implementation. The conversion script we have in this repo (and in the Colab notebook) doesn't really work with models from the official StyleGAN-ADA PyTorch version. You might also have better luck if you fine-tune from an existing model rather than train from scratch.
I think this youtube channel has some guides on training with new datasets: https://www.youtube.com/channel/UCaZuPdmZ380SFUMKHVsv_AA
Hi,
I trained StyleGAN-ADA on my dataset (resolution 256x256), using the TensorFlow version of ADA. Unfortunately, the convert_weight.py script fails when I try to convert the TensorFlow weights to PyTorch. I get the following error:
Traceback (most recent call last):
File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 247, in <module>
state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 160, in fill_statedict
convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 100, in update
raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
ValueError: Shape mismatch: torch.Size([1, 3, 256, 1, 1]) vs torch.Size([1, 3, 512, 1, 1])
Any clues for resolving this would be really appreciated. Thanks!
Hi,
Could you try running the convert script with --channel_multiplier 1 (instead of the default, which is 2)?
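For context, the shape mismatch reported above (256 vs 512 input channels on the 64x64 ToRGB layer) is what you would expect when the target generator is built twice as wide as the trained one. Below is a minimal sketch of how the per-resolution widths scale with the multiplier. The table is an assumption, reproduced from memory of the Rosinality-style generator, so check it against the model code before relying on it; `channel_widths` is a hypothetical helper name.

```python
def channel_widths(channel_multiplier=2):
    """Feature-map widths per resolution (assumed Rosinality-style table)."""
    return {
        4: 512,
        8: 512,
        16: 512,
        32: 512,
        # From 64x64 upward the width scales with the multiplier.
        64: 256 * channel_multiplier,
        128: 128 * channel_multiplier,
        256: 64 * channel_multiplier,
        512: 32 * channel_multiplier,
        1024: 16 * channel_multiplier,
    }


# Compare the widths that differ between multiplier 1 and multiplier 2.
for cm in (1, 2):
    widths = channel_widths(cm)
    print(cm, {res: widths[res] for res in (64, 128, 256)})
```

Under this assumption, a multiplier of 1 gives 256 channels at 64x64, which matches the checkpoint side of the error, while the default of 2 gives 512.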
Thanks a lot for the help!
Setting --channel_multiplier 1 in the convert_weight.py script resolved the above error.
But now I get the following error when I try to train a text-guided StyleGAN-NADA model, using the converted weights as my frozen generator.
Loading base models...
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-d9c9711b741e> in <module>()
69
70 print("Loading base models...")
---> 71 net = ZSSGAN(args)
72 print("Models loaded! Starting training...")
73
2 frames
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, args)
154
155 # Set up frozen (source) generator
--> 156 self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size).to(self.device)
157 self.generator_frozen.freeze_layers()
158 self.generator_frozen.eval()
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, checkpoint_path, latent_size, map_layers, img_size, channel_multiplier, device)
29 checkpoint = torch.load(checkpoint_path, map_location=device)
30
---> 31 self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
32
33 with torch.no_grad():
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1481 if len(error_msgs) > 0:
1482 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483 self.__class__.__name__, "\n\t".join(error_msgs)))
1484 return _IncompatibleKeys(missing_keys, unexpected_keys)
1485
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
It's the same channel_multiplier issue. However, I see that I was not currently passing the channel_multiplier argument to the generator constructor.
I just pushed an update that fixes this. Just grab the new version and add the same --channel_multiplier 1 argument when you call train.py.
Thanks
I pulled the changes and passed the --channel_multiplier 1 argument when calling train.py.
Getting the following error:
Loading base models...
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-f7063a26894a> in <module>()
70
71 print("Loading base models...")
---> 72 net = ZSSGAN(args)
73 print("Models loaded! Starting training...")
74
2 frames
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, args)
154
155 # Set up frozen (source) generator
--> 156 self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size, channel_multiplier=args.channel_multiplier).to(self.device)
157 self.generator_frozen.freeze_layers()
158 self.generator_frozen.eval()
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, checkpoint_path, latent_size, map_layers, img_size, channel_multiplier, device)
29 checkpoint = torch.load(checkpoint_path, map_location=device)
30
---> 31 self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
32
33 with torch.no_grad():
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1481 if len(error_msgs) > 0:
1482 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483 self.__class__.__name__, "\n\t".join(error_msgs)))
1484 return _IncompatibleKeys(missing_keys, unexpected_keys)
1485
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
Looks like your model is also using 2 mapping network layers instead of the standard 8.
Can you try modifying line 22 of ZSSGAN/model/ZSSGAN.py so that the default value for map_layers is 2 and not 8?
Basically change it so it reads:
def __init__(self, checkpoint_path, latent_size=512, map_layers=2, img_size=256, channel_multiplier=2, device='cuda:0'):
If that works for you, I'll see about adding the option as a command line argument so you don't have to play with the code itself.
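If you are unsure how many mapping layers a converted checkpoint has, you can count them from the key names before editing anything. This is a minimal sketch, assuming the mapping network's linear layers appear as `style.<index>.weight` keys (as in the missing-key list above); `infer_n_mlp` is a hypothetical helper, not part of this repo.

```python
import re


def infer_n_mlp(state_dict_keys):
    """Count mapping-network linear layers from keys like 'style.3.weight'."""
    pattern = re.compile(r"^style\.(\d+)\.weight$")
    matches = (pattern.match(key) for key in state_dict_keys)
    # Each distinct index with a weight tensor is one linear layer.
    return len({int(m.group(1)) for m in matches if m})


# Example with the key names a 2-layer mapping network would produce.
example_keys = [
    "style.1.weight", "style.1.bias",
    "style.2.weight", "style.2.bias",
    "convs.0.conv.weight",
]
print(infer_n_mlp(example_keys))
```

With a real checkpoint you would pass `checkpoint["g_ema"].keys()` after `torch.load`, and use the result as the `map_layers` value.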
I set the value of map_layers to 2 and it works now!
Thanks a lot for all the help!
No problem :) Keep in mind that if you later want to load this model with rosinality's code in order to generate more images, you'll need to similarly set channel_multiplier and the number of mapping layers (I believe his code calls it n_mlp).
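The channel multiplier can be read off the checkpoint in a similar way. This is a rough sketch, assuming the 64x64 ToRGB weight is stored under `to_rgbs.3.conv.weight` with shape `[1, 3, C, 1, 1]` (as in the error messages above) and that `C` equals 256 times the multiplier; `infer_channel_multiplier` and the `base_width` default are illustrative assumptions.

```python
def infer_channel_multiplier(shapes, key="to_rgbs.3.conv.weight", base_width=256):
    """Derive the channel multiplier from the input width of one ToRGB layer.

    `shapes` maps parameter names to shape tuples, for example
    {name: tuple(tensor.shape) for name, tensor in checkpoint["g_ema"].items()}.
    """
    in_channels = shapes[key][2]
    multiplier, remainder = divmod(in_channels, base_width)
    if remainder or multiplier < 1:
        raise ValueError(
            "Unexpected width %d for %s; cannot infer multiplier" % (in_channels, key)
        )
    return multiplier


# Shapes taken from the two sides of the size-mismatch error in this thread.
print(infer_channel_multiplier({"to_rgbs.3.conv.weight": (1, 3, 256, 1, 1)}))
print(infer_channel_multiplier({"to_rgbs.3.conv.weight": (1, 3, 512, 1, 1)}))
```

The first call corresponds to the converted checkpoint and the second to the model built with the default multiplier.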
If that's it, I'll close the issue for now. Feel free to re-open it or start a new issue if you need additional help!
@rinongal Hi, thanks for the great release. I'm hitting an error with my custom dataset. Even after adding the "--channel_multiplier 1" argument when calling train.py, the error still occurs. I have no idea how to fix it.
Initializing networks...
Traceback (most recent call last):
File "train.py", line 161, in <module>
train(args)
File "train.py", line 58, in train
net = ZSSGAN(args)
File "/mnt/StyleGAN-nada-main/ZSSGAN/model/ZSSGAN.py", line 275, in __init__
self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size).to(self.device)
File "/mnt/StyleGAN-nada-main/ZSSGAN/model/ZSSGAN.py", line 136, in __init__
self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
@snow1929 are you sure you're running with the right --size argument? What is your base model and what is your full train command?
@rinongal Here is my full training command. "mydataset-network-snapshot-003000.pt" was trained on 256x256 images with stylegan2-ada.
python train.py --size 256 \
--batch 2 \
--n_sample 4 \
--output_dir /mnt/StyleGAN-nada-main/results/mydataset_256 \
--lr 0.002 \
--frozen_gen_ckpt /mnt/StyleGAN-nada-main/models/mydataset-network-snapshot-003000.pt \
--iter 301 \
--source_class "sunny" \
--target_class "rainy" \
--lambda_direction 1.0 \
--lambda_patch 0.0 \
--lambda_global 0.0 \
--lambda_texture 0.0 \
--lambda_manifold 0.0 \
--phase None \
--auto_layer_k 0 \
--auto_layer_iters 0 \
--auto_layer_batch 8 \
--output_interval 50 \
--clip_models "ViT-B/32" "ViT-B/16" \
--clip_model_weights 1.0 1.0 \
--mixing 0.0 \
--save_interval 50 \
--channel_multiplier 1
The `--channel_multiplier` arg is currently just being ignored. You need to either modify things here or at the point mentioned in this prior reply.
I tried modifying line 22 of ZSSGAN/model/ZSSGAN.py to:
def __init__(self, checkpoint_path, latent_size=512, map_layers=2, img_size=256, channel_multiplier=2, device='cuda:0'):
But the error still occurs.
Sorry, looks like the line moved in one of the commits since that reply. It's this line: https://github.com/rinongal/StyleGAN-nada/blob/dc8406ae2173ad186f8f03f3cadf65e613ac9364/ZSSGAN/model/ZSSGAN.py#L126
Basically if you want to change the map layers or the channel multiplier, you need to change the number there, or modify the training script to actually pass and use these arguments.
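As a sketch of the second option, the training script could expose both settings as arguments and forward them to the generator wrapper. The keyword names `img_size`, `channel_multiplier`, and `map_layers` come from the constructor signature quoted earlier in this thread; the `--n_mlp` flag, `build_parser`, and `generator_kwargs` are hypothetical names used only for illustration.

```python
import argparse


def build_parser():
    """Hypothetical subset of the training options relevant to model shape."""
    parser = argparse.ArgumentParser(description="Architecture options sketch")
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    # Hypothetical flag: number of mapping-network layers.
    parser.add_argument("--n_mlp", type=int, default=8)
    return parser


def generator_kwargs(args):
    """Collect the settings that would be forwarded to the generator wrapper."""
    return {
        "img_size": args.size,
        "channel_multiplier": args.channel_multiplier,
        "map_layers": args.n_mlp,
    }


args = build_parser().parse_args(
    ["--size", "256", "--channel_multiplier", "1", "--n_mlp", "2"]
)
print(generator_kwargs(args))
```

The resulting dictionary would then be unpacked into the generator constructor call alongside the checkpoint path, so the values no longer need to be hard-coded as defaults.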
@rinongal Thanks a lot for helping me. I fixed the error with this method.
@rinongal Excuse me, I hit another error when trying to generate a video. I generated the source latent (mydataset_w_plus.npy) with StyleCLIP.
According to the error message, the problem seems to be with the model checkpoint rather than the latent, but I have no idea how to fix it.
python generate_videos.py \
> --ckpt /mnt/StyleGAN-nada/models/mydataset-network-snapshot-003000.pt \
> --out_dir /mnt/StyleGAN-nada/video/ \
> --source_latent /mnt/StyleGAN-nada/laten/mydataset_w_plus.npy \
> --target_latents /mnt/StyleGAN-nada/laten/ \
> --size 256 \
> --channel_multiplier 1
Generating video using checkpoint: /mnt/StyleGAN-nada/models/mydataset-network-snapshot-003000.pt
Traceback (most recent call last):
File "generate_videos.py", line 252, in <module>
g_ema.load_state_dict(checkpoint['g_ema'])
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
@snow1929 Looks like your model has 2 mapping network layers and you're trying to load into a model that expects 8. Try to adjust it here: https://github.com/rinongal/StyleGAN-nada/blob/dc8406ae2173ad186f8f03f3cadf65e613ac9364/ZSSGAN/generate_videos.py#L236 or promote it to a full arg and pass that when running the script.
Hi!
Thanks for putting out the great work! I am interested in training the frozen generator on a custom dataset. Could you please guide me, or kindly share the training code? Would sincerely appreciate any help.
Thanks!