lucidrains / stylegan2-pytorch

Simplest working implementation of StyleGAN2, a state-of-the-art generative adversarial network, in PyTorch, enabling everyone to experience disentanglement.
https://thispersondoesnotexist.com
MIT License
3.71k stars, 586 forks

cannot load model when using Contrastive Loss Regularization #227

xuboming8 closed this issue 3 years ago

xuboming8 commented 3 years ago

When I use Contrastive Loss Regularization (--cl-reg), the 'stylegan2_pytorch --generate' command fails after training. It seems that the model cannot be loaded correctly:

continuing from previous epoch - 0
loading from version 1.8.0
unable to load save model. please try downgrading the package to the version specified by the saved model
Traceback (most recent call last):
  File "/home/10301003/anaconda3/envs/pytorch1.6/bin/stylegan2_pytorch", line 8, in <module>
    sys.exit(main())
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/stylegan2_pytorch/cli.py", line 187, in main
    fire.Fire(train_from_folder)
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/stylegan2_pytorch/cli.py", line 160, in train_from_folder
    model.load(load_from)
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1394, in load
    raise e
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1391, in load
    self.GAN.load_state_dict(load_data['GAN'])
  File "/home/10301003/anaconda3/envs/pytorch1.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1044, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for StyleGAN2:
    Unexpected key(s) in state_dict:
    "D_cl.net.net.blocks.0.conv_res.weight", "D_cl.net.net.blocks.0.conv_res.bias", "D_cl.net.net.blocks.0.net.0.weight", "D_cl.net.net.blocks.0.net.0.bias", "D_cl.net.net.blocks.0.net.2.weight", "D_cl.net.net.blocks.0.net.2.bias", "D_cl.net.net.blocks.0.downsample.0.f", "D_cl.net.net.blocks.0.downsample.1.weight", "D_cl.net.net.blocks.0.downsample.1.bias",
    "D_cl.net.net.blocks.1.conv_res.weight", "D_cl.net.net.blocks.1.conv_res.bias", "D_cl.net.net.blocks.1.net.0.weight", "D_cl.net.net.blocks.1.net.0.bias", "D_cl.net.net.blocks.1.net.2.weight", "D_cl.net.net.blocks.1.net.2.bias", "D_cl.net.net.blocks.1.downsample.0.f", "D_cl.net.net.blocks.1.downsample.1.weight", "D_cl.net.net.blocks.1.downsample.1.bias",
    "D_cl.net.net.blocks.2.conv_res.weight", "D_cl.net.net.blocks.2.conv_res.bias", "D_cl.net.net.blocks.2.net.0.weight", "D_cl.net.net.blocks.2.net.0.bias", "D_cl.net.net.blocks.2.net.2.weight", "D_cl.net.net.blocks.2.net.2.bias", "D_cl.net.net.blocks.2.downsample.0.f", "D_cl.net.net.blocks.2.downsample.1.weight", "D_cl.net.net.blocks.2.downsample.1.bias",
    "D_cl.net.net.blocks.3.conv_res.weight", "D_cl.net.net.blocks.3.conv_res.bias", "D_cl.net.net.blocks.3.net.0.weight", "D_cl.net.net.blocks.3.net.0.bias", "D_cl.net.net.blocks.3.net.2.weight", "D_cl.net.net.blocks.3.net.2.bias", "D_cl.net.net.blocks.3.downsample.0.f", "D_cl.net.net.blocks.3.downsample.1.weight", "D_cl.net.net.blocks.3.downsample.1.bias",
    "D_cl.net.net.blocks.4.conv_res.weight", "D_cl.net.net.blocks.4.conv_res.bias", "D_cl.net.net.blocks.4.net.0.weight", "D_cl.net.net.blocks.4.net.0.bias", "D_cl.net.net.blocks.4.net.2.weight", "D_cl.net.net.blocks.4.net.2.bias", "D_cl.net.net.blocks.4.downsample.0.f", "D_cl.net.net.blocks.4.downsample.1.weight", "D_cl.net.net.blocks.4.downsample.1.bias",
    "D_cl.net.net.blocks.5.conv_res.weight", "D_cl.net.net.blocks.5.conv_res.bias", "D_cl.net.net.blocks.5.net.0.weight", "D_cl.net.net.blocks.5.net.0.bias", "D_cl.net.net.blocks.5.net.2.weight", "D_cl.net.net.blocks.5.net.2.bias", "D_cl.net.net.blocks.5.downsample.0.f", "D_cl.net.net.blocks.5.downsample.1.weight", "D_cl.net.net.blocks.5.downsample.1.bias",
    "D_cl.net.net.blocks.6.conv_res.weight", "D_cl.net.net.blocks.6.conv_res.bias", "D_cl.net.net.blocks.6.net.0.weight", "D_cl.net.net.blocks.6.net.0.bias", "D_cl.net.net.blocks.6.net.2.weight", "D_cl.net.net.blocks.6.net.2.bias",
    "D_cl.net.net.attn_blocks.0.0.fn.fn.to_q.weight", "D_cl.net.net.attn_blocks.0.0.fn.fn.to_kv.net.0.weight", "D_cl.net.net.attn_blocks.0.0.fn.fn.to_kv.net.1.weight", "D_cl.net.net.attn_blocks.0.0.fn.fn.to_out.weight", "D_cl.net.net.attn_blocks.0.0.fn.fn.to_out.bias", "D_cl.net.net.attn_blocks.0.0.fn.norm.g", "D_cl.net.net.attn_blocks.0.0.fn.norm.b", "D_cl.net.net.attn_blocks.0.1.fn.fn.0.weight", "D_cl.net.net.attn_blocks.0.1.fn.fn.0.bias", "D_cl.net.net.attn_blocks.0.1.fn.fn.2.weight", "D_cl.net.net.attn_blocks.0.1.fn.fn.2.bias", "D_cl.net.net.attn_blocks.0.1.fn.norm.g", "D_cl.net.net.attn_blocks.0.1.fn.norm.b",
    "D_cl.net.net.attn_blocks.1.0.fn.fn.to_q.weight", "D_cl.net.net.attn_blocks.1.0.fn.fn.to_kv.net.0.weight", "D_cl.net.net.attn_blocks.1.0.fn.fn.to_kv.net.1.weight", "D_cl.net.net.attn_blocks.1.0.fn.fn.to_out.weight", "D_cl.net.net.attn_blocks.1.0.fn.fn.to_out.bias", "D_cl.net.net.attn_blocks.1.0.fn.norm.g", "D_cl.net.net.attn_blocks.1.0.fn.norm.b", "D_cl.net.net.attn_blocks.1.1.fn.fn.0.weight", "D_cl.net.net.attn_blocks.1.1.fn.fn.0.bias", "D_cl.net.net.attn_blocks.1.1.fn.fn.2.weight", "D_cl.net.net.attn_blocks.1.1.fn.fn.2.bias", "D_cl.net.net.attn_blocks.1.1.fn.norm.g", "D_cl.net.net.attn_blocks.1.1.fn.norm.b",
    "D_cl.net.net.final_conv.weight", "D_cl.net.net.final_conv.bias", "D_cl.net.net.to_logit.weight", "D_cl.net.net.to_logit.bias", "D_cl.projection.0.weight", "D_cl.projection.2.weight".

When I removed --cl-reg, the '--generate' command ran correctly. Is something wrong?
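The RuntimeError above is PyTorch's strict state-dict loading at work: `load_state_dict` (strict by default) raises when the checkpoint holds keys that the freshly built model does not define. Every unexpected key here starts with `D_cl.`, which appears to be the contrastive-learner wrapper that only exists when the model is built with --cl-reg. The pure-Python sketch below mimics that key comparison on plain lists of key names; `diff_state_dict_keys` and `top_level_prefixes` are hypothetical helpers (not part of stylegan2_pytorch or PyTorch), and the `G.`/`D.` keys are made-up placeholders.

```python
from collections import Counter
from typing import Dict, Iterable, List


def diff_state_dict_keys(checkpoint_keys: Iterable[str],
                         model_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Compare key sets the way strict loading does (hypothetical helper).

    "unexpected": keys only in the checkpoint; "missing": keys only in the model.
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return {
        "unexpected": sorted(ckpt - model),
        "missing": sorted(model - ckpt),
    }


def top_level_prefixes(keys: Iterable[str]) -> Counter:
    """Count keys by their first dotted component, e.g. 'D_cl'."""
    return Counter(key.split(".", 1)[0] for key in keys)


# Checkpoint saved from a model trained with --cl-reg. The G/D keys are
# made-up placeholders; the D_cl keys are copied from the traceback.
checkpoint_keys = [
    "G.example.weight",
    "D.example.weight",
    "D_cl.net.net.final_conv.bias",
    "D_cl.projection.0.weight",
    "D_cl.projection.2.weight",
]
# Model rebuilt for --generate without --cl-reg: no D_cl submodule.
model_keys = ["G.example.weight", "D.example.weight"]

diff = diff_state_dict_keys(checkpoint_keys, model_keys)
print(diff["missing"])                          # []
print(top_level_prefixes(diff["unexpected"]))   # Counter({'D_cl': 3})
```

In PyTorch itself, `load_state_dict(..., strict=False)` would skip such keys instead of raising, but silently ignoring keys can hide a genuine mismatch; rebuilding the model with the same flags used in training avoids the question entirely.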

lucidrains commented 3 years ago

You'll also have to pass '--cl-reg' along with '--generate' (but I can fix this later in the week so you won't have to).
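Following that advice, the generation command should repeat the --cl-reg flag used in training. A minimal sketch, assuming you already have the checkpoint's state-dict key names as strings: the hypothetical `build_generate_args` appends '--cl-reg' when any key starts with 'D_cl.'. Only '--generate' and '--cl-reg' come from this thread; any other flags you normally pass (run name, model directory, and so on) are not covered here.

```python
from typing import Iterable, List


def uses_cl_reg(state_dict_keys: Iterable[str]) -> bool:
    """True if the checkpoint contains the contrastive-learner wrapper (D_cl)."""
    return any(key.startswith("D_cl.") for key in state_dict_keys)


def build_generate_args(state_dict_keys: Iterable[str],
                        base_cmd: str = "stylegan2_pytorch") -> List[str]:
    """Assemble a CLI argument list that repeats --cl-reg when needed.

    Hypothetical helper: it runs nothing, it only mirrors the advice
    "invoke --cl-reg together with --generate".
    """
    args = [base_cmd, "--generate"]
    if uses_cl_reg(state_dict_keys):
        args.append("--cl-reg")
    return args


keys = ["D.example.weight", "D_cl.projection.0.weight"]  # placeholder keys
print(build_generate_args(keys))
# ['stylegan2_pytorch', '--generate', '--cl-reg']
```

The resulting list could be handed to `subprocess.run(args, check=True)`; keeping the flag decision in one place avoids forgetting it between training and generation.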

Have you had any success with this experimental setting?

xuboming8 commented 3 years ago

Thanks, it works. I also have a problem with Feature Quantization (--fq-layers [1,2] --fq-dict-size 512):

Traceback (most recent call last):
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/bin/stylegan2_pytorch", line 8, in <module>
    sys.exit(main())
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/cli.py", line 187, in main
    fire.Fire(train_from_folder)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/cli.py", line 178, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/cli.py", line 52, in run_training
    model.load(load_from)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1372, in load
    self.load_config()
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 934, in load_config
    self.init_GAN()
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 907, in init_GAN
    self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 709, in __init__
    self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten')
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/contrastive_learner/contrastive_learner.py", line 172, in __init__
    self.forward(torch.randn(1, 3, image_size, image_size))
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/contrastive_learner/contrastive_learner.py", line 216, in forward
    queries = query_encoder(transform_fn(x))
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/contrastive_learner/contrastive_learner.py", line 128, in forward
    _ = self.net(x)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 681, in forward
    x, _, loss = q_block(x)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xuboming/anaconda3/envs/pytorch1.6.0/lib/python3.8/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 117, in forward
    out, loss = self.fn(x)
ValueError: too many values to unpack (expected 2)

Since you said you have not noticed any dramatic changes with vector quantization, I'll have to leave this strategy out during training.
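The last frame of that traceback explains the ValueError: the wrapper at line 117 unpacks exactly two values from `self.fn(x)`, but the wrapped quantizer returned more, while the frame at line 681 unpacks three values from `q_block(x)`. The traceback alone does not say what the extra value is. A common layout for vector quantizers is (quantized, codebook indices, loss), and the pure-Python sketch below assumes that layout: `toy_quantize` is a hypothetical nearest-codebook lookup on scalars, and `unpack_quantizer_output` is a hypothetical wrapper that accepts either the two-value or the three-value return. It is not necessarily how release 1.8.1 fixed the bug.

```python
from typing import List, Sequence, Tuple


def toy_quantize(values: Sequence[float],
                 codebook: Sequence[float]) -> Tuple[List[float], List[int], float]:
    """Snap each scalar to its nearest codebook entry (hypothetical toy).

    Returns (quantized, indices, loss); loss is the mean squared distance
    between inputs and their chosen codes. `values` must be non-empty.
    """
    indices = [min(range(len(codebook)), key=lambda i: (v - codebook[i]) ** 2)
               for v in values]
    quantized = [codebook[i] for i in indices]
    loss = sum((v - q) ** 2 for v, q in zip(values, quantized)) / len(values)
    return quantized, indices, loss


def unpack_quantizer_output(result: tuple) -> Tuple[object, float]:
    """Accept (out, loss) or (out, indices, loss) and return (out, loss)."""
    if len(result) == 2:
        out, loss = result
    elif len(result) == 3:
        out, _, loss = result  # drop the codebook indices
    else:
        raise ValueError(f"expected 2 or 3 values, got {len(result)}")
    return out, loss


result = toy_quantize([0.1, 0.9, 0.45], codebook=[0.0, 0.5, 1.0])
# A plain `out, loss = result` would raise ValueError (too many values to unpack).
out, loss = unpack_quantizer_output(result)
print(out)  # [0.0, 1.0, 0.5]
```

Checking the tuple length keeps the caller working across both return conventions; pinning the quantizer package version is the stricter alternative.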

lucidrains commented 3 years ago

@xuboming8 fixed the feature quantization bug! :pray: https://github.com/lucidrains/stylegan2-pytorch/releases/tag/1.8.1