rinongal / StyleGAN-nada

http://stylegan-nada.github.io/
MIT License

How can we re-train the frozen generator on a custom dataset? #25

Closed: romesa-khan closed this issue 2 years ago

romesa-khan commented 2 years ago

Hi!

Thanks for putting out the great work! I am interested in training the frozen generator on a custom dataset. Could you please guide me, or kindly share the training code? Would sincerely appreciate any help.

Thanks!

rinongal commented 2 years ago

Hi,

Thank you for your interest in our work!

How big is your custom dataset? This project is mostly focused on training without data, but if you tell me roughly how much data you have (10? 100? 5000 images?), maybe I could point you at projects that are a good fit?

romesa-khan commented 2 years ago

Thanks for the prompt reply! I have ~1200 images.

rinongal commented 2 years ago

With ~1200 images you could try training a StyleGAN-ADA model. Just make sure you train with either the TensorFlow version of ADA or the Rosinality PyTorch implementation. The conversion script we have in this repo (and in the Colab notebook) doesn't really work with models from the official StyleGAN-ADA PyTorch version. You might also have better luck if you fine-tune from an existing model rather than training from scratch.

I think this YouTube channel has some guides on training with new datasets: https://www.youtube.com/channel/UCaZuPdmZ380SFUMKHVsv_AA

romesa-khan commented 2 years ago

Hi,

I trained StyleGAN-ADA on my dataset (resolution 256x256) using the TensorFlow version of ADA. Unfortunately, the convert_weight.py script fails when I try to convert the TensorFlow weights to PyTorch. I get the following error:

Traceback (most recent call last):
  File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 247, in <module>
    state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
  File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 160, in fill_statedict
    convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
  File "/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/convert_weight.py", line 100, in update
    raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
ValueError: Shape mismatch: torch.Size([1, 3, 256, 1, 1]) vs torch.Size([1, 3, 512, 1, 1])

Any clues for resolving this would be really appreciated. Thanks!

rinongal commented 2 years ago

Hi,

Could you try running the convert script with --channel_multiplier 1 (instead of the default, which is 2)?

romesa-khan commented 2 years ago

Thanks a lot for the help! Setting --channel_multiplier 1 in the convert_weight.py script resolved the above error. But now I get the following error when I try to train a text-guided StyleGAN-NADA model, using the converted weights as my frozen generator.

Loading base models...
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-d9c9711b741e> in <module>()
     69 
     70 print("Loading base models...")
---> 71 net = ZSSGAN(args)
     72 print("Models loaded! Starting training...")
     73 

2 frames
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, args)
    154 
    155         # Set up frozen (source) generator
--> 156         self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size).to(self.device)
    157         self.generator_frozen.freeze_layers()
    158         self.generator_frozen.eval()

/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, checkpoint_path, latent_size, map_layers, img_size, channel_multiplier, device)
     29         checkpoint = torch.load(checkpoint_path, map_location=device)
     30 
---> 31         self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
     32 
     33         with torch.no_grad():

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias". 
    size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
    size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
    size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
    size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
    size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
    size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
    size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
    size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
    size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
    size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
    size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
    size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
    size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
rinongal commented 2 years ago

It's the same channel_multiplier issue. However, I see that I wasn't actually passing the channel_multiplier argument to the generator constructor.

I just pushed an update that fixes this. Just grab the new version and add the same --channel_multiplier 1 argument when you call train.py.

romesa-khan commented 2 years ago

Thanks! I pulled the changes and passed the --channel_multiplier 1 argument when calling train.py, but now I'm getting the following error:

Loading base models...
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-f7063a26894a> in <module>()
     70 
     71 print("Loading base models...")
---> 72 net = ZSSGAN(args)
     73 print("Models loaded! Starting training...")
     74 

2 frames
/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, args)
    154 
    155         # Set up frozen (source) generator
--> 156         self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size, channel_multiplier=args.channel_multiplier).to(self.device)
    157         self.generator_frozen.freeze_layers()
    158         self.generator_frozen.eval()

/content/gdrive/MyDrive/ML/Code/StyleGAN_NADA/stylegan_nada/ZSSGAN/model/ZSSGAN.py in __init__(self, checkpoint_path, latent_size, map_layers, img_size, channel_multiplier, device)
     29         checkpoint = torch.load(checkpoint_path, map_location=device)
     30 
---> 31         self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
     32 
     33         with torch.no_grad():

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias". 
rinongal commented 2 years ago

Looks like your model is also using 2 mapping network layers instead of the standard 8.

Can you try modifying line 22 of ZSSGAN/model/ZSSGAN.py so that the default value for map_layers is 2 and not 8?

Basically change it so it reads: def __init__(self, checkpoint_path, latent_size=512, map_layers=2, img_size=256, channel_multiplier=2, device='cuda:0'):

If that works for you, I'll see about adding the option as a command line argument so you don't have to play with the code itself.

romesa-khan commented 2 years ago

I set the value of map_layers to 2 and it works now! Thanks a lot for all the help!

rinongal commented 2 years ago

No problem :) Keep in mind that if you later want to load this model with rosinality's code in order to generate more images, you'll need to similarly set channel_multiplier and the number of mapping layers (I believe his code calls it n_mlp).
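
For example, reloading the converted checkpoint there would look roughly like this (a minimal sketch assuming rosinality's standard Generator interface and repo layout; the checkpoint path is just a placeholder):

import torch
from model import Generator  # rosinality's stylegan2-pytorch

# Match the settings the checkpoint was converted with:
# 256x256 output, 2 mapping layers (n_mlp) and channel_multiplier 1.
g_ema = Generator(256, 512, n_mlp=2, channel_multiplier=1).to("cuda")

checkpoint = torch.load("converted_model.pt", map_location="cuda")
g_ema.load_state_dict(checkpoint["g_ema"])
g_ema.eval()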

If that's it, I'll close the issue for now. Feel free to re-open it or start a new issue if you need additional help!

snow1929 commented 1 year ago

@rinongal Hi, thanks for the great release. I'm running into this error with my custom dataset. Even after adding the --channel_multiplier 1 argument when I call train.py, the error still happens. I have no idea how to fix it.

Initializing networks...
Traceback (most recent call last):
  File "train.py", line 161, in <module>
    train(args)
  File "train.py", line 58, in train
    net = ZSSGAN(args)
  File "/mnt/StyleGAN-nada-main/ZSSGAN/model/ZSSGAN.py", line 275, in __init__
    self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size).to(self.device)
  File "/mnt/StyleGAN-nada-main/ZSSGAN/model/ZSSGAN.py", line 136, in __init__
    self.generator.load_state_dict(checkpoint["g_ema"], strict=True)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
        size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
        size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
        size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
        size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
        size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
        size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
        size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
        size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
        size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
        size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
        size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
        size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
        size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
rinongal commented 1 year ago

@snow1929 are you sure you're running with the right --size argument? What is your base model and what is your full train command?

snow1929 commented 1 year ago

@rinongal Here is my full training command. "mydataset-network-snapshot-003000.pt" was trained on 256x256 images with StyleGAN2-ADA.

python train.py --size 256 \
--batch 2 \
--n_sample 4 \
--output_dir /mnt/StyleGAN-nada-main/results/mydataset_256 \
--lr 0.002 \
--frozen_gen_ckpt /mnt/StyleGAN-nada-main/models/mydataset-network-snapshot-003000.pt \
--iter 301 \
--source_class "sunny" \
--target_class "rainy" \
--lambda_direction 1.0 \
--lambda_patch 0.0 \
--lambda_global 0.0 \
--lambda_texture 0.0 \
--lambda_manifold 0.0 \
--phase None \
--auto_layer_k 0 \
--auto_layer_iters 0 \
--auto_layer_batch 8 \
--output_interval 50 \
--clip_models "ViT-B/32" "ViT-B/16" \
--clip_model_weights 1.0 1.0 \
--mixing 0.0 \
--save_interval 50 \
--channel_multiplier 1
rinongal commented 1 year ago

The --channel_multiplier arg is currently just being ignored. You need to either modify things here or at the point mentioned in the prior reply above.

snow1929 commented 1 year ago

I tried modifying line 22 of ZSSGAN/model/ZSSGAN.py:

def __init__(self, checkpoint_path, latent_size=512, map_layers=2, img_size=256, channel_multiplier=2, device='cuda:0'):

But the error still happens.

rinongal commented 1 year ago

Sorry, looks like the line moved in one of the commits since that reply. It's this line: https://github.com/rinongal/StyleGAN-nada/blob/dc8406ae2173ad186f8f03f3cadf65e613ac9364/ZSSGAN/model/ZSSGAN.py#L126

Basically, if you want to change the map layers or the channel multiplier, you need to change the numbers there, or modify the training script to actually pass and use these arguments.
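
Roughly, the two options look like this (just a sketch of the relevant lines, not a full patch; the exact surrounding code may have moved between commits):

# Option 1: change the defaults on the SG2Generator constructor
# (the line linked above) to match your checkpoint:
def __init__(self, checkpoint_path, latent_size=512, map_layers=2,
             img_size=256, channel_multiplier=1, device='cuda:0'):
    ...

# Option 2: pass the values explicitly where ZSSGAN builds its generators:
self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size,
                                     map_layers=2, channel_multiplier=1).to(self.device)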

snow1929 commented 1 year ago

@rinongal Thanks for helping me out a lot. I fixed the error with this method.

snow1929 commented 1 year ago

@rinongal Excuse me, I'm running into another error when trying to generate a video. I generated the source latent (mydataset_w_plus.npy) with StyleCLIP.

According to the error message, it seems the problem is not the latent but the model checkpoint, but I have no idea how to fix it.

python generate_videos.py \
--ckpt /mnt/StyleGAN-nada/models/mydataset-network-snapshot-003000.pt \
--out_dir /mnt/StyleGAN-nada/video/ \
--source_latent /mnt/StyleGAN-nada/laten/mydataset_w_plus.npy  \
--target_latents /mnt/StyleGAN-nada/laten/ \
--size 256 \
--channel_multiplier 1
Generating video using checkpoint: /mnt/StyleGAN-nada/models/mydataset-network-snapshot-003000.pt
Traceback (most recent call last):
  File "generate_videos.py", line 252, in <module>
    g_ema.load_state_dict(checkpoint['g_ema'])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
        Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias".
rinongal commented 1 year ago

@snow1929 Looks like your model has 2 mapping network layers and you're trying to load into a model that expects 8. Try to adjust it here: https://github.com/rinongal/StyleGAN-nada/blob/dc8406ae2173ad186f8f03f3cadf65e613ac9364/ZSSGAN/generate_videos.py#L236 or promote it to a full arg and pass that when running the script.
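
If you go the command-line route, the change would look roughly like this (a sketch only; the exact variable names inside generate_videos.py may differ):

# add an argument for the number of mapping layers
parser.add_argument("--n_mlp", type=int, default=8,
                    help="number of mapping network layers in the checkpoint")

# then use it (together with --size and --channel_multiplier) when building g_ema
g_ema = Generator(args.size, 512, args.n_mlp,
                  channel_multiplier=args.channel_multiplier).to(device)

With your model you'd then pass --n_mlp 2 when running the script.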

snow1929 commented 1 year ago

@rinongal What kind of model did you use to translate images into latent space? ReStyle or StyleFlow? Would you mind sharing the detailed command? I'm having trouble with this part. 😞