lucidrains / stylegan2-pytorch

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
https://thispersondoesnotexist.com
MIT License
3.71k stars, 587 forks

Number of channels in Generator is incorrect at higher resolutions, causing the model to have far too many parameters #81

Closed by bob80333 4 years ago

bob80333 commented 4 years ago

At 1024x1024 (default settings other than image_size), the generator block channels of this repo's model look like this:

Gen block channels
64 -> 8192
Gen block channels
8192 -> 4096
Gen block channels
4096 -> 2048
Gen block channels
2048 -> 1024
Gen block channels
1024 -> 512
Gen block channels
512 -> 256
Gen block channels
256 -> 128
Gen block channels
128 -> 64
Gen block channels
64 -> 32

number of parameters: 1237966560

For contrast, the original tensorflow implementation looks like this for 1024x1024:

G                               Params    OutputShape          WeightShape     
---                             ---       ---                  ---             
latents_in                      -         (?, 512)             -               
labels_in                       -         (?, 0)               -               
lod                             -         ()                   -               
dlatent_avg                     -         (512,)               -               
G_mapping/latents_in            -         (?, 512)             -               
G_mapping/labels_in             -         (?, 0)               -               
G_mapping/Normalize             -         (?, 512)             -               
G_mapping/Dense0                262656    (?, 512)             (512, 512)      
G_mapping/Dense1                262656    (?, 512)             (512, 512)      
G_mapping/Dense2                262656    (?, 512)             (512, 512)      
G_mapping/Dense3                262656    (?, 512)             (512, 512)      
G_mapping/Dense4                262656    (?, 512)             (512, 512)      
G_mapping/Dense5                262656    (?, 512)             (512, 512)      
G_mapping/Dense6                262656    (?, 512)             (512, 512)      
G_mapping/Dense7                262656    (?, 512)             (512, 512)      
G_mapping/Broadcast             -         (?, 18, 512)         -               
G_mapping/dlatents_out          -         (?, 18, 512)         -               
Truncation/Lerp                 -         (?, 18, 512)         -               
G_synthesis/dlatents_in         -         (?, 18, 512)         -               
G_synthesis/4x4/Const           8192      (?, 512, 4, 4)       (1, 512, 4, 4)  
G_synthesis/4x4/Conv            2622465   (?, 512, 4, 4)       (3, 3, 512, 512)
G_synthesis/4x4/ToRGB           264195    (?, 3, 4, 4)         (1, 1, 512, 3)  
G_synthesis/8x8/Conv0_up        2622465   (?, 512, 8, 8)       (3, 3, 512, 512)
G_synthesis/8x8/Conv1           2622465   (?, 512, 8, 8)       (3, 3, 512, 512)
G_synthesis/8x8/Upsample        -         (?, 3, 8, 8)         -               
G_synthesis/8x8/ToRGB           264195    (?, 3, 8, 8)         (1, 1, 512, 3)  
G_synthesis/16x16/Conv0_up      2622465   (?, 512, 16, 16)     (3, 3, 512, 512)
G_synthesis/16x16/Conv1         2622465   (?, 512, 16, 16)     (3, 3, 512, 512)
G_synthesis/16x16/Upsample      -         (?, 3, 16, 16)       -               
G_synthesis/16x16/ToRGB         264195    (?, 3, 16, 16)       (1, 1, 512, 3)  
G_synthesis/32x32/Conv0_up      2622465   (?, 512, 32, 32)     (3, 3, 512, 512)
G_synthesis/32x32/Conv1         2622465   (?, 512, 32, 32)     (3, 3, 512, 512)
G_synthesis/32x32/Upsample      -         (?, 3, 32, 32)       -               
G_synthesis/32x32/ToRGB         264195    (?, 3, 32, 32)       (1, 1, 512, 3)  
G_synthesis/64x64/Conv0_up      2622465   (?, 512, 64, 64)     (3, 3, 512, 512)
G_synthesis/64x64/Conv1         2622465   (?, 512, 64, 64)     (3, 3, 512, 512)
G_synthesis/64x64/Upsample      -         (?, 3, 64, 64)       -               
G_synthesis/64x64/ToRGB         264195    (?, 3, 64, 64)       (1, 1, 512, 3)  
G_synthesis/128x128/Conv0_up    1442561   (?, 256, 128, 128)   (3, 3, 512, 256)
G_synthesis/128x128/Conv1       721409    (?, 256, 128, 128)   (3, 3, 256, 256)
G_synthesis/128x128/Upsample    -         (?, 3, 128, 128)     -               
G_synthesis/128x128/ToRGB       132099    (?, 3, 128, 128)     (1, 1, 256, 3)  
G_synthesis/256x256/Conv0_up    426369    (?, 128, 256, 256)   (3, 3, 256, 128)
G_synthesis/256x256/Conv1       213249    (?, 128, 256, 256)   (3, 3, 128, 128)
G_synthesis/256x256/Upsample    -         (?, 3, 256, 256)     -               
G_synthesis/256x256/ToRGB       66051     (?, 3, 256, 256)     (1, 1, 128, 3)  
G_synthesis/512x512/Conv0_up    139457    (?, 64, 512, 512)    (3, 3, 128, 64) 
G_synthesis/512x512/Conv1       69761     (?, 64, 512, 512)    (3, 3, 64, 64)  
G_synthesis/512x512/Upsample    -         (?, 3, 512, 512)     -               
G_synthesis/512x512/ToRGB       33027     (?, 3, 512, 512)     (1, 1, 64, 3)   
G_synthesis/1024x1024/Conv0_up  51297     (?, 32, 1024, 1024)  (3, 3, 64, 32)  
G_synthesis/1024x1024/Conv1     25665     (?, 32, 1024, 1024)  (3, 3, 32, 32)  
G_synthesis/1024x1024/Upsample  -         (?, 3, 1024, 1024)   -               
G_synthesis/1024x1024/ToRGB     16515     (?, 3, 1024, 1024)   (1, 1, 32, 3)   
G_synthesis/images_out          -         (?, 3, 1024, 1024)   -               
G_synthesis/noise0              -         (1, 1, 4, 4)         -               
G_synthesis/noise1              -         (1, 1, 8, 8)         -               
G_synthesis/noise2              -         (1, 1, 8, 8)         -               
G_synthesis/noise3              -         (1, 1, 16, 16)       -               
G_synthesis/noise4              -         (1, 1, 16, 16)       -               
G_synthesis/noise5              -         (1, 1, 32, 32)       -               
G_synthesis/noise6              -         (1, 1, 32, 32)       -               
G_synthesis/noise7              -         (1, 1, 64, 64)       -               
G_synthesis/noise8              -         (1, 1, 64, 64)       -               
G_synthesis/noise9              -         (1, 1, 128, 128)     -               
G_synthesis/noise10             -         (1, 1, 128, 128)     -               
G_synthesis/noise11             -         (1, 1, 256, 256)     -               
G_synthesis/noise12             -         (1, 1, 256, 256)     -               
G_synthesis/noise13             -         (1, 1, 512, 512)     -               
G_synthesis/noise14             -         (1, 1, 512, 512)     -               
G_synthesis/noise15             -         (1, 1, 1024, 1024)   -               
G_synthesis/noise16             -         (1, 1, 1024, 1024)   -               
images_out                      -         (?, 3, 1024, 1024)   -               
---                             ---       ---                  ---             
Total                           30370060                                       

In the original implementation, no convolution has more than 512 channels, but in this implementation the channel count grows to 8192. Is that intended?
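The generator table can be sanity-checked with a few lines of Python. This is a sketch under assumptions, not NVIDIA's code: the `nf` schedule and its defaults (`fmap_base = 16 << 10`, `fmap_max = 512`) are recalled from `networks_stylegan2.py` and should be checked against the source, and the per-layer formulas are inferred from the WeightShape and Params columns above (a modulated conv is taken to hold its weight, a style dense layer from 512 to `c_in`, a bias, and one noise-strength scalar). All helper names are made up for illustration.

```python
def nf(stage, fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    """Feature maps at a stage, clipped to [fmap_min, fmap_max] (assumed defaults)."""
    raw = int(fmap_base / (2.0 ** (stage * fmap_decay)))
    return max(fmap_min, min(raw, fmap_max))


def dense(n_in, n_out):
    # Fully connected layer: weight matrix plus bias.
    return n_in * n_out + n_out


def mod_conv(c_in, c_out, k, w_dim=512):
    # Inferred from the table: conv weight + style dense (w_dim -> c_in)
    # + bias + one noise-strength scalar.
    return k * k * c_in * c_out + dense(w_dim, c_in) + c_out + 1


def to_rgb(c_in, w_dim=512, rgb=3):
    # Inferred from the table: 1x1 conv weight + style dense + bias, no noise.
    return c_in * rgb + dense(w_dim, c_in) + rgb


def generator_params(resolution=1024, latent=512, mapping_layers=8):
    res_log2 = resolution.bit_length() - 1        # 1024 -> 10
    total = mapping_layers * dense(latent, latent)  # G_mapping/Dense0..7
    total += nf(1) * 4 * 4                          # 4x4/Const
    total += mod_conv(nf(1), nf(1), 3, latent)      # 4x4/Conv
    total += to_rgb(nf(1), latent)                  # 4x4/ToRGB
    for res in range(3, res_log2 + 1):              # 8x8 up to 1024x1024
        c_in, c_out = nf(res - 2), nf(res - 1)
        total += mod_conv(c_in, c_out, 3, latent)   # Conv0_up
        total += mod_conv(c_out, c_out, 3, latent)  # Conv1
        total += to_rgb(c_out, latent)              # ToRGB
    return total
```

Under these assumptions `generator_params()` comes out to 30370060, the Total in the table, and `[nf(s) for s in range(1, 10)]` gives 512 five times followed by 256, 128, 64, 32. The cap is what keeps the early layers at 512 instead of 8192.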

lucidrains commented 4 years ago

@bob80333 Hi Eric, this is a great finding. I originally transcribed this repository from another re-implementation, so it may not be completely accurate. Do you know if the discriminator also caps the channels at what I presume is the latent z dimension? Does this hold if you go even higher, to say 2048?

lucidrains commented 4 years ago

@bob80333 let me know, and I'll make the change!

bob80333 commented 4 years ago

Here's the discriminator from the same run on the original repo. I'll try changing the latent dimension to see whether the cap is tied to it.

D                     Params    OutputShape          WeightShape     
---                   ---       ---                  ---             
images_in             -         (?, 3, 1024, 1024)   -               
labels_in             -         (?, 0)               -               
1024x1024/FromRGB     128       (?, 32, 1024, 1024)  (1, 1, 3, 32)   
1024x1024/Conv0       9248      (?, 32, 1024, 1024)  (3, 3, 32, 32)  
1024x1024/Conv1_down  18496     (?, 64, 512, 512)    (3, 3, 32, 64)  
1024x1024/Skip        2048      (?, 64, 512, 512)    (1, 1, 32, 64)  
512x512/Conv0         36928     (?, 64, 512, 512)    (3, 3, 64, 64)  
512x512/Conv1_down    73856     (?, 128, 256, 256)   (3, 3, 64, 128) 
512x512/Skip          8192      (?, 128, 256, 256)   (1, 1, 64, 128) 
256x256/Conv0         147584    (?, 128, 256, 256)   (3, 3, 128, 128)
256x256/Conv1_down    295168    (?, 256, 128, 128)   (3, 3, 128, 256)
256x256/Skip          32768     (?, 256, 128, 128)   (1, 1, 128, 256)
128x128/Conv0         590080    (?, 256, 128, 128)   (3, 3, 256, 256)
128x128/Conv1_down    1180160   (?, 512, 64, 64)     (3, 3, 256, 512)
128x128/Skip          131072    (?, 512, 64, 64)     (1, 1, 256, 512)
64x64/Conv0           2359808   (?, 512, 64, 64)     (3, 3, 512, 512)
64x64/Conv1_down      2359808   (?, 512, 32, 32)     (3, 3, 512, 512)
64x64/Skip            262144    (?, 512, 32, 32)     (1, 1, 512, 512)
32x32/Conv0           2359808   (?, 512, 32, 32)     (3, 3, 512, 512)
32x32/Conv1_down      2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
32x32/Skip            262144    (?, 512, 16, 16)     (1, 1, 512, 512)
16x16/Conv0           2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
16x16/Conv1_down      2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
16x16/Skip            262144    (?, 512, 8, 8)       (1, 1, 512, 512)
8x8/Conv0             2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
8x8/Conv1_down        2359808   (?, 512, 4, 4)       (3, 3, 512, 512)
8x8/Skip              262144    (?, 512, 4, 4)       (1, 1, 512, 512)
4x4/MinibatchStddev   -         (?, 513, 4, 4)       -               
4x4/Conv              2364416   (?, 512, 4, 4)       (3, 3, 513, 512)
4x4/Dense0            4194816   (?, 512)             (8192, 512)     
Output                513       (?, 1)               (512, 1)        
scores_out            -         (?, 1)               -               
---                   ---       ---                  ---             
Total                 29012513  

bob80333 commented 4 years ago

It appears that the cap is separate from the latents:

Cap: https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py#L315

Latents:
https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py#L254
https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py#L256
https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py#L259
https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py#L309
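As a cross-check that the cap is its own setting, the discriminator total posted earlier in the thread can be reproduced from a capped channel schedule alone. This is a sketch under assumptions: the `fmap_base` and `fmap_max` defaults are recalled from the NVIDIA code rather than quoted from it, the layer formulas are inferred from that table's WeightShape and Params columns (the Skip convolutions appear to have no bias), and the helper names are hypothetical.

```python
def nf(stage, fmap_base=16 << 10, fmap_max=512):
    """Channels at a stage, capped at fmap_max (assumed defaults)."""
    return min(int(fmap_base / 2 ** stage), fmap_max)


def conv(c_in, c_out, k, bias=True):
    # Plain convolution: k*k*c_in*c_out weights, optional bias per output channel.
    return k * k * c_in * c_out + (c_out if bias else 0)


def discriminator_params(resolution=1024, fmap_max=512, rgb=3):
    res_log2 = resolution.bit_length() - 1       # 1024 -> 10

    def f(stage):
        return nf(stage, fmap_max=fmap_max)

    total = conv(rgb, f(res_log2 - 1), 1)        # FromRGB
    for res in range(res_log2, 2, -1):           # 1024x1024 down to 8x8
        c_in, c_out = f(res - 1), f(res - 2)
        total += conv(c_in, c_in, 3)             # Conv0
        total += conv(c_in, c_out, 3)            # Conv1_down
        total += conv(c_in, c_out, 1, bias=False)  # Skip
    total += conv(f(1) + 1, f(1), 3)             # 4x4/Conv (MinibatchStddev adds a channel)
    total += f(1) * 4 * 4 * f(0) + f(0)          # 4x4/Dense0
    total += f(0) + 1                            # Output
    return total
```

With the assumed defaults, `discriminator_params()` gives 29012513, matching the table. No latent size appears anywhere in the computation, which is consistent with the cap being configured separately from the latents.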

bob80333 commented 4 years ago

Also, the 4x4 constant should have the same number of channels as the first convolution (512 channels in the 1024x1024 case).
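Putting the two observations together (the cap, and the constant matching the first convolution), here is a hypothetical channel schedule. It is not this repo's code: the doubling rule, the base width of 16, and the initial width of 4 times that base are guesses chosen only because they reproduce the printout at the top of the issue, and the names `block_channels`, `network_capacity`, and `fmap_max` are used here for illustration.

```python
import math


def block_channels(image_size, network_capacity=16, fmap_max=None):
    """Return (in, out) channel pairs for each generator block.

    fmap_max=None mimics the uncapped printout; an integer applies the cap
    and sizes the initial constant to match the first convolution.
    """
    num_layers = int(math.log2(image_size)) - 1
    # Widths double per layer, listed from the lowest resolution upward.
    filters = [network_capacity * 2 ** (i + 1) for i in range(num_layers)][::-1]
    if fmap_max is None:
        init_channels = 4 * network_capacity      # 64, as in the printout
    else:
        filters = [min(f, fmap_max) for f in filters]
        init_channels = filters[0]                # constant matches first conv
    chans = [init_channels, *filters]
    return list(zip(chans[:-1], chans[1:]))
```

`block_channels(1024)` yields the nine pairs from the printout, starting at (64, 8192) and ending at (64, 32). `block_channels(1024, fmap_max=512)` starts at (512, 512) and never exceeds 512 channels.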

lucidrains commented 4 years ago

@bob80333 gotcha! thanks for digging into this! I originally built the repository so people can get a first-hand experience of disentanglement, but now I see a bunch of people wanting to reach for higher resolutions, so I'll try to make it scale better!

bob80333 commented 4 years ago

I was interested in trying out the data augmentations that you added, but I wanted to make sure that the models were the same, so I created a 1024 resolution one in a Colab notebook and noticed the difference in channels. I might look further into the training details to try to ensure similar results between the NVIDIA implementation and this one, since I don't know of any other implementations with data augmentation.

lucidrains commented 4 years ago

@bob80333 done! https://github.com/lucidrains/stylegan2-pytorch/commit/d3207f4505b7c01723e1983e8a6d70ce330bb81c

lucidrains commented 4 years ago

@bob80333 the data augmentation technique is so simple that it almost doesn't warrant an implementation. Karras' own paper on data augmentation builds in a kind of curriculum driven by a heuristic, but other papers did just fine augmenting the images with a fixed probability.
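To show how little code the fixed-probability variant needs, here is a minimal sketch. It is not taken from this repo or from any of the papers mentioned: images are plain nested lists, the horizontal flip stands in for whatever augmentation is actually used, and `hflip` and `random_apply` are hypothetical names.

```python
import random


def hflip(image):
    """Mirror an image stored as a list of rows."""
    return [row[::-1] for row in image]


def random_apply(image, augment, prob, rng=random):
    """Apply `augment` with fixed probability `prob`; otherwise return the image unchanged."""
    return augment(image) if rng.random() < prob else image
```

Passing a seeded `random.Random` instance as `rng` makes the choice reproducible across runs.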