open-mmlab / mmgeneration

MMGeneration is a powerful toolkit for generative models, based on PyTorch and MMCV.
https://mmgeneration.readthedocs.io/en/latest/
Apache License 2.0

StyleGAN2-ADA: Training differences compared to NVlabs implementation #394

Open mvidela31 opened 2 years ago

mvidela31 commented 2 years ago

Hi everybody,

I tried transfer learning from the stylegan2_config-f_lsun-car_384x512 pretrained model on a custom dataset of truck images (with the same aspect ratio), but the training differs from what I got using the official PyTorch StyleGAN2-ADA implementation.

With the NVlabs implementation (using --cfg paper512), the visualized training samples show a smooth transition from the car images of the pretrained model into the target trucks. With the mmgen implementation, however, the visualized training samples show the initial car images quickly degrading into noise (unintelligible images); only afterwards do they start to capture the target mode (the truck images), and with worse image quality.

I suspect that the observed training difference is due to a different choice of hyperparameters, but the default hyperparameters of both implementations (NVlabs and mmgen) seem to be almost identical.

Am I missing some other important hyperparameters, or is the observed training difference due to intrinsic implementation differences?

- **NVlabs training logs:**

Output directory:   /content/drive/TrucksGAN/00000-rgb_resize-paper512-resumecustom
Training data:      ./datasets/rgb.zip
Training duration:  25000 kimg
Number of GPUs:     1
Number of images:   537
Image resolution:   256
Conditional model:  False
Dataset x-flips:    False

Creating output directory...
Launching processes...
Loading training set...

Num images:  537
Image shape: [3, 512, 512]
Label shape: [0]

Constructing networks...
Resuming from "./checkpoints/stylegan2-car-config-f.pkl"
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.

Generator             Parameters  Buffers  Output shape        Datatype
---                   ---         ---      ---                 ---
mapping.fc0           262656      -        [8, 512]            float32
mapping.fc1           262656      -        [8, 512]            float32
mapping.fc2           262656      -        [8, 512]            float32
mapping.fc3           262656      -        [8, 512]            float32
mapping.fc4           262656      -        [8, 512]            float32
mapping.fc5           262656      -        [8, 512]            float32
mapping.fc6           262656      -        [8, 512]            float32
mapping.fc7           262656      -        [8, 512]            float32
mapping               -           512      [8, 16, 512]        float32
synthesis.b4.conv1    2622465     32       [8, 512, 4, 4]      float32
synthesis.b4.torgb    264195      -        [8, 3, 4, 4]        float32
synthesis.b4:0        8192        16       [8, 512, 4, 4]      float32
synthesis.b4:1        -           -        [8, 512, 4, 4]      float32
synthesis.b8.conv0    2622465     80       [8, 512, 8, 8]      float32
synthesis.b8.conv1    2622465     80       [8, 512, 8, 8]      float32
synthesis.b8.torgb    264195      -        [8, 3, 8, 8]        float32
synthesis.b8:0        -           16       [8, 512, 8, 8]      float32
synthesis.b8:1        -           -        [8, 512, 8, 8]      float32
synthesis.b16.conv0   2622465     272      [8, 512, 16, 16]    float32
synthesis.b16.conv1   2622465     272      [8, 512, 16, 16]    float32
synthesis.b16.torgb   264195      -        [8, 3, 16, 16]      float32
synthesis.b16:0       -           16       [8, 512, 16, 16]    float32
synthesis.b16:1       -           -        [8, 512, 16, 16]    float32
synthesis.b32.conv0   2622465     1040     [8, 512, 32, 32]    float32
synthesis.b32.conv1   2622465     1040     [8, 512, 32, 32]    float32
synthesis.b32.torgb   264195      -        [8, 3, 32, 32]      float32
synthesis.b32:0       -           16       [8, 512, 32, 32]    float32
synthesis.b32:1       -           -        [8, 512, 32, 32]    float32
synthesis.b64.conv0   2622465     4112     [8, 512, 64, 64]    float16
synthesis.b64.conv1   2622465     4112     [8, 512, 64, 64]    float16
synthesis.b64.torgb   264195      -        [8, 3, 64, 64]      float16
synthesis.b64:0       -           16       [8, 512, 64, 64]    float16
synthesis.b64:1       -           -        [8, 512, 64, 64]    float32
synthesis.b128.conv0  1442561     16400    [8, 256, 128, 128]  float16
synthesis.b128.conv1  721409      16400    [8, 256, 128, 128]  float16
synthesis.b128.torgb  132099      -        [8, 3, 128, 128]    float16
synthesis.b128:0      -           16       [8, 256, 128, 128]  float16
synthesis.b128:1      -           -        [8, 256, 128, 128]  float32
synthesis.b256.conv0  426369      65552    [8, 128, 256, 256]  float16
synthesis.b256.conv1  213249      65552    [8, 128, 256, 256]  float16
synthesis.b256.torgb  66051       -        [8, 3, 256, 256]    float16
synthesis.b256:0      -           16       [8, 128, 256, 256]  float16
synthesis.b256:1      -           -        [8, 128, 256, 256]  float32
synthesis.b512.conv0  139457      262160   [8, 64, 512, 512]   float16
synthesis.b512.conv1  69761       262160   [8, 64, 512, 512]   float16
synthesis.b512.torgb  33027       -        [8, 3, 512, 512]    float16
synthesis.b512:0      -           16       [8, 64, 512, 512]   float16
synthesis.b512:1      -           -        [8, 64, 512, 512]   float32
---                   ---         ---      ---                 ---
Total                 30276583    699904   -                   -

Discriminator  Parameters  Buffers  Output shape        Datatype
---            ---         ---      ---                 ---
b512.fromrgb   256         16       [8, 64, 512, 512]   float16
b512.skip      8192        16       [8, 128, 256, 256]  float16
b512.conv0     36928       16       [8, 64, 512, 512]   float16
b512.conv1     73856       16       [8, 128, 256, 256]  float16
b512           -           16       [8, 128, 256, 256]  float16
b256.skip      32768       16       [8, 256, 128, 128]  float16
b256.conv0     147584      16       [8, 128, 256, 256]  float16
b256.conv1     295168      16       [8, 256, 128, 128]  float16
b256           -           16       [8, 256, 128, 128]  float16
b128.skip      131072      16       [8, 512, 64, 64]    float16
b128.conv0     590080      16       [8, 256, 128, 128]  float16
b128.conv1     1180160     16       [8, 512, 64, 64]    float16
b128           -           16       [8, 512, 64, 64]    float16
b64.skip       262144      16       [8, 512, 32, 32]    float16
b64.conv0      2359808     16       [8, 512, 64, 64]    float16
b64.conv1      2359808     16       [8, 512, 32, 32]    float16
b64            -           16       [8, 512, 32, 32]    float16
b32.skip       262144      16       [8, 512, 16, 16]    float32
b32.conv0      2359808     16       [8, 512, 32, 32]    float32
b32.conv1      2359808     16       [8, 512, 16, 16]    float32
b32            -           16       [8, 512, 16, 16]    float32
b16.skip       262144      16       [8, 512, 8, 8]      float32
b16.conv0      2359808     16       [8, 512, 16, 16]    float32
b16.conv1      2359808     16       [8, 512, 8, 8]      float32
b16            -           16       [8, 512, 8, 8]      float32
b8.skip        262144      16       [8, 512, 4, 4]      float32
b8.conv0       2359808     16       [8, 512, 8, 8]      float32
b8.conv1       2359808     16       [8, 512, 4, 4]      float32
b8             -           16       [8, 512, 4, 4]      float32
b4.mbstd       -           -        [8, 513, 4, 4]      float32
b4.conv        2364416     16       [8, 512, 4, 4]      float32
b4.fc          4194816     -        [8, 512]            float32
b4.out         513         -        [8, 1]              float32
---            ---         ---      ---                 ---
Total          28982849    480      -                   -

Setting up augmentation...
Distributing across 1 GPUs...
Setting up training phases...
Exporting sample images...
Initializing logs...
Training for 25000 kimg...

tick 0     kimg 0.1      time 1m 37s       sec/tick 33.6    sec/kimg 525.64  maintenance 63.3   cpumem 5.95   gpumem 10.51  augment 0.000


- **MMgen training logs:**

2022-08-23 14:28:37,271 - mmgen - INFO - Environment info:

sys.platform: linux
Python: 3.7.13 (default, Apr 24 2022, 01:04:09) [GCC 7.5.0]
CUDA available: True
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.1.TC455_06.29190527_0
GPU 0: Tesla P100-PCIE-16GB
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.10.0+cu111
PyTorch compiling details: PyTorch built with:

TorchVision: 0.11.0+cu111
OpenCV: 4.6.0
MMCV: 1.5.0
MMGen: 0.7.1+
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.1

2022-08-23 14:28:37,589 - mmgen - INFO - Distributed training: True
2022-08-23 14:28:37,797 - mmgen - INFO - Config:
dataset_type = 'UnconditionalImageDataset'
train_pipeline = [
    dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
    dict(type='Resize', keys=['real_img'], scale=(512, 384)),
    dict(type='NumpyPad', keys=['real_img'], padding=((64, 64), (0, 0), (0, 0))),
    dict(type='Normalize', keys=['real_img'], mean=[127.5, 127.5, 127.5],
         std=[127.5, 127.5, 127.5], to_rgb=False),
    dict(type='ImageToTensor', keys=['real_img']),
    dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(
            type='UnconditionalImageDataset',
            imgs_root='/content/data/rgb',
            pipeline=[
                dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
                dict(type='Resize', keys=['real_img'], scale=(512, 384)),
                dict(type='NumpyPad', keys=['real_img'], padding=((64, 64), (0, 0), (0, 0))),
                dict(type='Normalize', keys=['real_img'], mean=[127.5, 127.5, 127.5],
                     std=[127.5, 127.5, 127.5], to_rgb=False),
                dict(type='ImageToTensor', keys=['real_img']),
                dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
            ])),
    val=dict(
        type='UnconditionalImageDataset',
        imgs_root='/content/data/rgb',
        pipeline=[
            dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
            dict(type='Resize', keys=['real_img'], scale=(512, 384)),
            dict(type='NumpyPad', keys=['real_img'], padding=((64, 64), (0, 0), (0, 0))),
            dict(type='Normalize', keys=['real_img'], mean=[127.5, 127.5, 127.5],
                 std=[127.5, 127.5, 127.5], to_rgb=False),
            dict(type='ImageToTensor', keys=['real_img']),
            dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
        ]))
d_reg_interval = 16
g_reg_interval = 4
g_reg_ratio = 0.8
d_reg_ratio = 0.9411764705882353
model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(type='StyleGANv2Generator', out_size=512, style_channels=512),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=512,
        data_aug=dict(
            type='ADAAug',
            aug_pipeline=dict(
                xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1,
                saturation=1),
            ada_kimg=500)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(
        type='R1GradientPenalty',
        loss_weight=80.0,
        interval=16,
        norm_mode='HWC',
        data_info=dict(real_data='real_imgs', discriminator='disc')),
    gen_auxiliary_loss=dict(
        type='GeneratorPathRegularizer',
        loss_weight=8.0,
        pl_batch_shrink=2,
        interval=4,
        data_info=dict(generator='gen', num_batches='batch_size')))
train_cfg = dict(use_ema=True)
test_cfg = None
optimizer = dict(
    generator=dict(type='Adam', lr=0.0016, betas=(0, 0.9919919678228657)),
    discriminator=dict(
        type='Adam', lr=0.0018823529411764706, betas=(0, 0.9905854573074332)))
checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=40)
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='MlflowLoggerHook', exp_name='images-generation', log_model=False)
    ])
custom_hooks = [
    dict(type='VisualizeUnconditionalSamples', output_dir='training_samples', interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=1,
        interp_cfg=dict(momentum=0.9977843871238888),
        priority='VERY_HIGH')
]
runner = dict(type='DynamicIterBasedRunner', is_dynamic_ddp=True, pass_training_status=True)
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = '/content/drive/MyDrive/MMGen_GenerationTrain/pretrained/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth'
resume_from = None
workflow = [('train', 10000)]
find_unused_parameters = True
cudnn_benchmark = True
opencv_num_threads = 0
mp_start_method = 'fork'
ema_half_life = 10.0
lr_config = None
total_iters = 100002
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl='/content/inception_data.pkl',
        inception_args=dict(type='StyleGAN')))
evaluation = None
work_dir = '/content/drive/MyDrive/MMGen_GenerationTrain'
gpu_ids = range(0, 1)

2022-08-23 14:28:37,797 - mmgen - INFO - Set random seed to 0, deterministic: False, use_rank_shift: False
2022-08-23 14:28:38,809 - mmgen - INFO - dataset_name: <class 'mmgen.datasets.unconditional_image_dataset.UnconditionalImageDataset'>, total 537 images in imgs_root: /content/data/rgb
fatal: not a git repository (or any of the parent directories): .git
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
2022-08-23 14:28:43,351 - mmgen - INFO - load checkpoint from local path: /content/drive/MyDrive/COPEC/MMGen_GenerationTrain/pretrained/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth
2022-08-23 14:28:44,040 - mmgen - WARNING - The model and loaded state dict do not match exactly

missing keys in source state_dict: discriminator.ada_aug.log_buffer, discriminator.ada_aug.aug_pipeline.p, discriminator.ada_aug.aug_pipeline.Hz_geom, discriminator.ada_aug.aug_pipeline.Hz_fbank

2022-08-23 14:28:44,056 - mmgen - INFO - Start running, host: root@aa711de74988, work_dir: /content/drive/MyDrive/COPEC/MMGen_GenerationTrain
2022-08-23 14:28:44,056 - mmgen - INFO - workflow: [('train', 10000)], max: 100002 iters
2022-08-23 14:28:44,057 - mmgen - INFO - Checkpoints will be saved to /content/drive/MyDrive/COPEC/MMGen_GenerationTrain/ckpt/MMGen_GenerationTrain by HardDiskBackend.
2022/08/23 14:28:44 INFO mlflow.tracking.fluent: Experiment with name 'images-generation' does not exist. Creating a new experiment.
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
2022-08-23 14:30:51,461 - mmgen - INFO - Iter [100/100002] lr_generator: 1.600e-03 lr_discriminator: 1.882e-03, eta: 1 day, 11:17:19, time: 1.272, data_time: 0.008, memory: 13989, loss_disc_fake_g: 158.7404, loss_path_regular: 2.7484, loss: 0.1517, loss_disc_fake: 0.0175, loss_disc_real: 0.0909, loss_r1_gp: 0.6181, augment: 0.0016

mvidela31 commented 2 years ago

Hi all,

I realized that using --cfg paper512 in the NVLabs implementation modifies the following hyperparameters:

# NVlabs config
paper512_cfg = dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1,
                    lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8)
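For context, here is my reading of what each field controls in NVlabs' train.py (these glosses are my assumptions, not official documentation):

```python
# ref_gpus=8    -> number of GPUs the preset was tuned for
# kimg=25000    -> total training length, in thousands of real images shown
# mb=64         -> total minibatch size across all GPUs
# mbstd=8       -> group size of the discriminator's minibatch-stddev layer
# fmaps=1       -> channel multiplier (1 keeps the full config-f capacity)
# lrate=0.0025  -> base Adam learning rate for both G and D
# gamma=0.5     -> R1 gradient-penalty weight
# ema=20        -> half-life of the generator weight EMA, in kimg
# ramp=None     -> no EMA ramp-up
# map=8         -> number of mapping-network layers
```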

So I also modified the MMGeneration config file accordingly:

# MMGeneration config
cfg.ema_half_life = paper512_cfg['ema']  # defaults to 10.0
cfg.optimizer.generator.lr = paper512_cfg['lrate']  # defaults to 0.0016
cfg.optimizer.discriminator.lr = paper512_cfg['lrate']  # defaults to 0.0018823529411764706
# Here I assumed that nvlabs_gamma == mmgen_loss_weight / 2
cfg.model.disc_auxiliary_loss.loss_weight = paper512_cfg['gamma'] * 2  # defaults to 80.0
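While making these changes I noticed that the mmgen defaults are not raw base rates: they appear to be pre-scaled by the lazy-regularization ratios defined at the top of the config (g_reg_ratio = 4/5, d_reg_ratio = 16/17), which is the same scaling NVlabs applies internally at runtime (lr * reg_interval / (reg_interval + 1), beta ** ratio). A minimal sketch of that derivation, assuming a base lr of 0.002 and beta2 of 0.99 (the helper name is mine):

```python
# Sketch: deriving the mmgen optimizer defaults from base values via the
# lazy-regularization ratio reg_interval / (reg_interval + 1). NVlabs
# performs this scaling at runtime; mmgen appears to bake it into the
# config. The helper name is hypothetical.
def lazy_reg_scaled(base_lr, base_beta2, reg_interval):
    ratio = reg_interval / (reg_interval + 1)
    return base_lr * ratio, base_beta2 ** ratio

print(lazy_reg_scaled(0.002, 0.99, 4))   # ~(0.0016, 0.99199...): generator defaults
print(lazy_reg_scaled(0.002, 0.99, 16))  # ~(0.00188, 0.99059...): discriminator defaults

# Under this reading, paper512's lrate=0.0025 should also be scaled
# before it goes into the mmgen config, rather than used as-is:
print(lazy_reg_scaled(0.0025, 0.99, 4))   # generator:     lr = 0.002
print(lazy_reg_scaled(0.0025, 0.99, 16))  # discriminator: lr ~ 0.00235
```

If that is right, plugging the raw 0.0025 into both optimizers (as I did above) overshoots the intended effective rates.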

In addition, I modified some other hyperparameters whose default values in the MMGeneration implementation differ from the NVlabs ones:

# NVlabs default values: style_mixing_prob=0.9 (OK), r1_gamma=10,
# pl_batch_shrink=2 (OK), pl_decay=0.01 (OK), pl_weight=2
# MMGeneration config
cfg.model.gen_auxiliary_loss.loss_weight = 2.0  # defaults to 8.0
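Similarly, the regularizer loss weights seem to have the lazy-regularization interval baked in: the default R1 loss_weight of 80.0 equals r1_gamma / 2 * interval for r1_gamma=10 and interval=16, and the default path-length loss_weight of 8.0 equals pl_weight * interval for pl_weight=2 and interval=4. If that reading is correct (again an assumption on my part, not verified against the mmgen source), the paper512 values would translate as below, rather than via gamma * 2 or the raw pl_weight:

```python
# Sketch of the weight conversion under that assumption (function names are mine):
def r1_loss_weight(r1_gamma, d_reg_interval=16):
    # NVlabs applies (gamma / 2) * ||grad||^2 and multiplies the lazily
    # regularized loss by the interval it is applied at.
    return r1_gamma / 2 * d_reg_interval

def pl_loss_weight(pl_weight, g_reg_interval=4):
    return pl_weight * g_reg_interval

print(r1_loss_weight(10))   # 80.0 -> matches the mmgen default
print(pl_loss_weight(2))    # 8    -> matches the mmgen default
print(r1_loss_weight(0.5))  # 4.0  -> paper512 gamma=0.5, not gamma * 2 == 1.0
print(pl_loss_weight(2))    # 8    -> paper512 keeps pl_weight=2, so no change needed
```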

I also checked that all other default hyperparameters are the same in both implementations (note that I removed the dict(type='Flip', keys=['real_img'], direction='horizontal') operation from the data pipelines, since I used "xflip": false in the NVlabs implementation). However, with all the mentioned changes, the training performance is even worse than with the previous hyperparameter configuration.
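To spell out the flip handling (my understanding of the correspondence): the Flip op I removed is dataset-level mirroring, i.e. NVlabs' "Dataset x-flips", while the xflip=1 entry inside ADAAug is ADA's probabilistic flip augmentation, the analogue of xflip in NVlabs' default 'bgc' augment pipe, which stays enabled in both runs:

```python
# Dataset-level mirroring (analogue of NVlabs "Dataset x-flips: False"),
# removed from both train/val pipelines:
#   dict(type='Flip', keys=['real_img'], direction='horizontal')

# ADA's probabilistic xflip, applied by the discriminator's augmentation
# pipeline (analogue of xflip in NVlabs' 'bgc' pipe), kept in both runs:
data_aug = dict(
    type='ADAAug',
    aug_pipeline=dict(
        xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
        xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
    ada_kimg=500)
```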

Could someone help me find the cause of the observed training difference between the MMGeneration and NVlabs implementations? Is there a way to replicate the training behavior of the NVlabs implementation?

zengyh1900 commented 1 year ago

Please check this issue. @plyfager

plyfager commented 1 year ago

Sorry for responding so late. We'll have a look and reply soon.