rosinality / stylegan2-pytorch

Implementation of Analyzing and Improving the Image Quality of StyleGAN (StyleGAN 2) in PyTorch
MIT License

finetune error based on stylegan2-ffhq-config-f.pt #271

Open Ethantequila opened 2 years ago

Ethantequila commented 2 years ago

I want to do some fine-tuning on my own dataset. I created an LMDB dataset following the guideline in the repo:

python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH

and then I want to fine-tune from stylegan2-ffhq-config-f.pt (converted from stylegan2-ffhq-config-f.pkl):

python train.py \
  --ckpt stylegan2-ffhq-config-f.pt \
  --channel_multiplier 1 \
  --augment \
  stylegan2-pytorch/datasets/ynby_lmdb_256

but I got an error like this:

load model: stylegan2-ffhq-config-f.pt
Traceback (most recent call last):
  File "train.py", line 490, in <module>
    generator.load_state_dict(ckpt["g"])
  File "/opt/packages/automation/miniconda3/envs/sam_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
    Unexpected key(s) in state_dict: "convs.12.conv.weight", "convs.12.conv.blur.kernel", "convs.12.conv.modulation.weight", "convs.12.conv.modulation.bias", "convs.12.noise.weight", "convs.12.activate.bias", "convs.13.conv.weight", "convs.13.conv.modulation.weight", "convs.13.conv.modulation.bias", "convs.13.noise.weight", "convs.13.activate.bias", "convs.14.conv.weight", "convs.14.conv.blur.kernel", "convs.14.conv.modulation.weight", "convs.14.conv.modulation.bias", "convs.14.noise.weight", "convs.14.activate.bias", "convs.15.conv.weight", "convs.15.conv.modulation.weight", "convs.15.conv.modulation.bias", "convs.15.noise.weight", "convs.15.activate.bias", "to_rgbs.6.bias", "to_rgbs.6.upsample.kernel", "to_rgbs.6.conv.weight", "to_rgbs.6.conv.modulation.weight", "to_rgbs.6.conv.modulation.bias", "to_rgbs.7.bias", "to_rgbs.7.upsample.kernel", "to_rgbs.7.conv.weight", "to_rgbs.7.conv.modulation.weight", "to_rgbs.7.conv.modulation.bias", "noises.noise_13", "noises.noise_14", "noises.noise_15", "noises.noise_16".
    size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
    size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
    size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
    size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
    size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 128, 3, 3]).
    size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 3, 3]).
    size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
    size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
    size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 64, 1, 1]).
    size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).

Do you guys have any idea how to fix this problem? Thanks!

rosinality commented 2 years ago

You will need to increase --size argument.

Ethantequila commented 2 years ago

You will need to increase --size argument.

Thanks for your reply. I tried adding the flag --size 256 to the command, but I got the same error:

python train.py --ckpt stylegan2-ffhq-config-f.pt --channel_multiplier 1 --augment --size 256 /data/YNBY/zhixing/stylegan2-pytorch/datasets/ynby_lmdb_256

RuntimeError: Error(s) in loading state_dict for Generator:
    Unexpected key(s) in state_dict: "convs.12.conv.weight", "convs.12.conv.blur.kernel", "convs.12.conv.modulation.weight", "convs.12.conv.modulation.bias", "convs.12.noise.weight", "convs.12.activate.bias", "convs.13.conv.weight", "convs.13.conv.modulation.weight", "convs.13.conv.modulation.bias", "convs.13.noise.weight", "convs.13.activate.bias", "convs.14.conv.weight", "convs.14.conv.blur.kernel", "convs.14.conv.modulation.weight", "convs.14.conv.modulation.bias", "convs.14.noise.weight", "convs.14.activate.bias", "convs.15.conv.weight", "convs.15.conv.modulation.weight", "convs.15.conv.modulation.bias", "convs.15.noise.weight", "convs.15.activate.bias", "to_rgbs.6.bias", "to_rgbs.6.upsample.kernel", "to_rgbs.6.conv.weight", "to_rgbs.6.conv.modulation.weight", "to_rgbs.6.conv.modulation.bias", "to_rgbs.7.bias", "to_rgbs.7.upsample.kernel", "to_rgbs.7.conv.weight", "to_rgbs.7.conv.modulation.weight", "to_rgbs.7.conv.modulation.bias", "noises.noise_13", "noises.noise_14", "noises.noise_15", "noises.noise_16".
    size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
    size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
    size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
I tried generating the LMDB dataset at three different sizes (256/512/1024) and tested with each, but I still got the same error as above. I am wondering if the input layer size of the model weights file (stylegan2-ffhq-config-f.pt) is mismatched?

rosinality commented 2 years ago

stylegan2-ffhq-config-f.pt is a 1024px model, so you will need to use --size 1024.
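
For what it's worth, a checkpoint's native resolution can be sanity-checked from its state_dict key names. This is a hedged sketch, not part of the repo: `checkpoint_resolution` is a hypothetical helper, and it assumes the `to_rgbs.N` layer naming used by this repo's Generator (visible in the traceback above). Each to_rgbs level doubles the output resolution starting from 8px, so size = 2 ** (highest index + 3).

```python
import re

def checkpoint_resolution(state_dict_keys):
    """Infer a rosinality-style StyleGAN2 generator's native resolution
    from its to_rgbs.N key names: size = 2 ** (max_index + 3)."""
    indices = [int(m.group(1))
               for k in state_dict_keys
               if (m := re.match(r"to_rgbs\.(\d+)\.", k))]
    if not indices:
        raise ValueError("no to_rgbs keys found; not a generator state_dict?")
    return 2 ** (max(indices) + 3)
```

In practice you would pass something like `torch.load(ckpt_path)["g"].keys()`; the traceback above shows keys up to to_rgbs.7, which gives 2 ** 10 = 1024, hence --size 1024.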

Ethantequila commented 2 years ago

stylegan2-ffhq-config-f.pt is 1024px model, so you will need to use --size 1024.

Thanks for your guidance. I tried with the flag --size 1024, but I got another issue:

python train.py \
>  --ckpt stylegan2-ffhq-config-f.pt \
>  --augment \
>  --size 1024 \
>  /data/YNBY/zhixing/stylegan2-pytorch/datasets/ynby_lmdb_1024

The error goes like this:

load model: /data/YNBY/zhixing/SAM/pretrained_models/stylegan2-ffhq-config-f.pt
Traceback (most recent call last):
  File "train.py", line 494, in <module>
    g_optim.load_state_dict(ckpt["g_optim"])
KeyError: 'g_optim'

Does this mean there is something wrong with the stylegan2-ffhq-config-f.pt model?

rosinality commented 2 years ago

Official pretrained checkpoints do not contain optimizer states. You can comment out that line.
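
Instead of commenting the lines out, the load could also be guarded so the same train.py works with both official (weights-only) and self-trained checkpoints. A minimal sketch, assuming nothing beyond the checkpoint dict structure seen in this thread; `load_optim_states` is an illustrative helper, not repo code:

```python
def load_optim_states(ckpt, g_optim, d_optim):
    """Restore optimizer states only if the checkpoint contains them.
    Official converted checkpoints ship weights only, so the "g_optim"
    and "d_optim" keys may be absent."""
    for key, optim in (("g_optim", g_optim), ("d_optim", d_optim)):
        if key in ckpt:
            optim.load_state_dict(ckpt[key])
        else:
            print(f"checkpoint has no '{key}'; optimizer starts fresh")
```

Fine-tuning from scratch optimizer state is usually fine for this use case.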

1406428260 commented 2 years ago

Hello, same question. When I comment out the lines g_optim.load_state_dict(ckpt["g_optim"]) and d_optim.load_state_dict(ckpt["d_optim"]), I get:

Traceback (most recent call last):
  File "train.py", line 531, in <module>
    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
  File "train.py", line 185, in train
    fake_pred = discriminator(fake_img)
  File "/home/godman/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/godman/zhf/workplace/stylegan2-pytorch/model.py", line 685, in forward
    group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
RuntimeError: shape '[4, -1, 1, 512, 4, 4]' is invalid for input of size 49152

1406428260 commented 2 years ago

[quotes the --size 1024 command and the resulting KeyError: 'g_optim' traceback from the earlier comment]

Did you succeed?

1406428260 commented 2 years ago

This was a batch size problem, which I have solved.
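
For later readers: the reshape in the traceback above comes from the discriminator's minibatch-stddev layer, which splits the batch into groups of up to 4 before computing per-group statistics, so the batch size must be divisible by the group size. A rough sketch of the constraint; `batch_size_ok` is an illustrative check rather than repo code, and the default group size of 4 is an assumption matching the `[4, -1, ...]` shape in the error:

```python
def batch_size_ok(batch, stddev_group=4):
    """The minibatch-stddev layer views [batch, C, H, W] as
    [group, batch // group, ...] with group = min(batch, stddev_group),
    which requires batch % group == 0."""
    group = min(batch, stddev_group)
    return batch % group == 0

# The failing run had batch 6 (49152 elements = 6 * 512 * 4 * 4), which
# cannot be viewed as [4, -1, 1, 512, 4, 4] because 6 % 4 != 0.
```

So with the defaults, picking a --batch divisible by 4 (e.g. 4, 8, 16) avoids this error.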