NVlabs / stylegan2-ada-pytorch

StyleGAN2-ADA - Official PyTorch implementation
https://arxiv.org/abs/2006.06676

Transfer learning from NVIDIA's pre-trained StyleGAN (FFHQ) #279

Open Bearwithchris opened 1 year ago

Bearwithchris commented 1 year ago

Hi,

I'm using the pre-trained pkl file https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl and attempting transfer learning (without augmentation) from FFHQ to CelebA-HQ:

python train.py --outdir=~/training-runs --data=~/datasets/FFHQ/GenderTrainSamples_0.025.zip --gpus=1 --workers 1 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl --aug=noaug --kimg 200

However, the initial generated images (fakes_init.png) look like this: (image)

Yet when I check the FID against the FFHQ dataset, it is around 9.

Can anyone explain what is going on?

JackywithaWhiteDog commented 1 year ago

Hi, I faced the same problem and found that the default model configuration differs from the one used to train the pre-trained checkpoint.

Solution

Change the configuration from --cfg=auto (the default) to --cfg=paper256 for this pre-trained model. (For other pre-trained models, use the configuration they were trained with.)
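Concretely, applied to the command from the original post (same paths and flags as above, only the --cfg flag added):

```shell
python train.py --outdir=~/training-runs \
  --data=~/datasets/FFHQ/GenderTrainSamples_0.025.zip \
  --gpus=1 --workers 1 \
  --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl \
  --aug=noaug --kimg 200 \
  --cfg=paper256  # match the configuration the checkpoint was trained with
```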

fakes_init.png with --cfg=auto: (image: fakes_init_auto)

fakes_init.png with --cfg=paper256: (image: fakes_init_paper256)

Explanation

The configuration controls the model's channel_base, the number of layers in the mapping network, and the minibatch standard deviation layer of the discriminator. For instance, the mapping network has 8 layers with --cfg=paper256, but only 2 layers with --cfg=auto.

To keep the model structure the same as the pre-trained one, you should ensure that fmaps, map, and mbstd in the configuration match those used when the model was trained.
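A condensed sketch of the relevant entries in train.py's cfg_specs table (values taken from the linked lines below; for --cfg=auto the -1 placeholders are filled in heuristically at runtime based on resolution and GPU count). Here fmaps scales channel_base, map is the mapping-network depth, and mbstd is the discriminator's minibatch-stddev group size; check_compatible is a hypothetical helper, not part of the repo:

```python
# Condensed from train.py's cfg_specs (sketch; see the linked lines for the full table).
cfg_specs = {
    'auto':     dict(fmaps=-1,  map=2, mbstd=-1),  # -1 = resolved heuristically at runtime
    'paper256': dict(fmaps=0.5, map=8, mbstd=8),
}

def check_compatible(cfg, pretrained):
    """Report the fields where the chosen cfg diverges from the pre-trained one."""
    a, b = cfg_specs[cfg], cfg_specs[pretrained]
    return [k for k in a if a[k] != b[k]]

print(check_compatible('auto', 'paper256'))  # -> ['fmaps', 'map', 'mbstd']
```

Any field in that list changes the network's architecture, so the resumed checkpoint can no longer be loaded cleanly.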

https://github.com/NVlabs/stylegan2-ada-pytorch/blob/6f160b3d22b8b178ebe533a50d4d5e63aedba21d/train.py#L154-L161

https://github.com/NVlabs/stylegan2-ada-pytorch/blob/6f160b3d22b8b178ebe533a50d4d5e63aedba21d/train.py#L176-L183

In addition, when loading the pre-trained model, the function copy_params_and_buffers silently ignores any parameters in the pre-trained model that have no matching destination, without reporting the inconsistency.

https://github.com/NVlabs/stylegan2-ada-pytorch/blob/6f160b3d22b8b178ebe533a50d4d5e63aedba21d/torch_utils/misc.py#L153-L160
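One way to surface such a silent skip is to diff the two sets of parameter names and shapes before copying. The sketch below (find_mismatches is a hypothetical helper, not part of the repo) uses plain {name: shape} dicts so it is framework-agnostic; the toy data mimics an 8-layer vs. 2-layer mapping network:

```python
def find_mismatches(src_shapes, dst_shapes):
    """Compare two {param_name: shape} dicts and report everything that a
    lenient copy like copy_params_and_buffers would silently skip: names
    present on only one side, or shared names whose shapes disagree."""
    missing_in_dst = sorted(set(src_shapes) - set(dst_shapes))
    missing_in_src = sorted(set(dst_shapes) - set(src_shapes))
    shape_diffs = sorted(k for k in set(src_shapes) & set(dst_shapes)
                         if src_shapes[k] != dst_shapes[k])
    return missing_in_dst, missing_in_src, shape_diffs

# Toy example: pre-trained mapping network with 8 layers vs. a fresh one with 2.
pretrained = {f'mapping.fc{i}.weight': (512, 512) for i in range(8)}
fresh      = {f'mapping.fc{i}.weight': (512, 512) for i in range(2)}
extra, missing, diffs = find_mismatches(pretrained, fresh)
print(extra)  # fc2..fc7 weights exist only in the pre-trained model
```

A non-empty result is a strong hint that the --cfg used for resuming does not match the checkpoint.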

githuboflk commented 1 year ago

> (quotes JackywithaWhiteDog's reply above)

@JackywithaWhiteDog Hi. Where did this weight come from? I can only find the 256x256 pre-trained weight in the transfer-learning folder, and that site is different from the one linked in this issue.

JackywithaWhiteDog commented 2 months ago

Hi @githuboflk, sorry I didn't notice your question earlier. I also used the pre-trained weight provided in the README, as you mentioned.

However, I believe the checkpoint linked in this issue is available from the NVIDIA NGC Catalog.