lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.
MIT License

Getting into "NaN detected for generator or discriminator" loop every time #92

Open artucalvo opened 3 years ago

artucalvo commented 3 years ago

I have tried running the algorithm on Colab with different datasets (256 and 512 px), batch sizes (16, 32), augmentation probabilities (0.25, 0.40) and gradient_accumulate_every values (4, 2, 1). However, every time I end up stuck in the NaN loop in less than an hour.

Here is one example run, in which GP quickly reaches 10.00 and stays there. Any thoughts on what is going on?

!lightweight_gan \
    --data $IMAGES_PATH \
    --results_dir $RESULTS_PATH \
    --models_dir $MODELS_PATH \
    --image-size 512 \
    --name LWG \
    --batch-size 32 \
    --gradient-accumulate-every 1 \
    --num-train-steps 1000000 \
    --save_every 1000 \
    --disc_output_size 5 \
    --aug-prob 0.4 \
    --aug-types [translation] \
    --amp \
    --new False
LWG</content/Training/>:   0% 0/1000000 [00:00<?, ?it/s]G: 1.87 | D: 3.59 | SS: 35.82
LWG</content/Training/>:   0% 45/1000000 [01:07<430:59:20,  1.55s/it]G: 1.09 | D: 0.42 | GP: 69.44 | SS: 0.42
LWG</content/Training/>:   0% 93/1000000 [02:30<406:26:33,  1.46s/it]G: 0.90 | D: 1.00 | GP: 12.81 | SS: 0.23
LWG</content/Training/>:   0% 149/1000000 [03:39<398:23:47,  1.43s/it]G: 1.37 | D: 1.33 | GP: 8.62 | SS: 0.25
LWG</content/Training/>:   0% 193/1000000 [04:42<399:40:24,  1.44s/it]G: -0.15 | D: 1.47 | GP: 1.17 | SS: 0.27
LWG</content/Training/>:   0% 249/1000000 [06:06<404:35:24,  1.46s/it]G: 0.55 | D: 1.59 | GP: 11.58 | SS: 0.22
LWG</content/Training/>:   0% 297/1000000 [07:14<392:55:39,  1.41s/it]G: 1.05 | D: 2.17 | GP: 3.01 | SS: 0.20
LWG</content/Training/>:   0% 345/1000000 [08:26<406:13:03,  1.46s/it]G: 0.00 | D: 1.82 | GP: 3.61 | SS: 0.32
LWG</content/Training/>:   0% 393/1000000 [09:34<396:13:20,  1.43s/it]G: 0.88 | D: 1.76 | GP: 0.54 | SS: 0.32
LWG</content/Training/>:   0% 449/1000000 [10:57<397:54:06,  1.43s/it]G: 0.56 | D: 1.48 | GP: 0.37 | SS: 0.39
LWG</content/Training/>:   0% 497/1000000 [12:06<394:28:55,  1.42s/it]G: 0.65 | D: 1.81 | GP: 0.51 | SS: 0.34
LWG</content/Training/>:   0% 545/1000000 [13:16<396:17:34,  1.43s/it]G: -0.23 | D: 1.93 | GP: 0.40 | SS: 0.34
LWG</content/Training/>:   0% 593/1000000 [14:25<391:16:11,  1.41s/it]G: -0.31 | D: 1.95 | GP: 0.22 | SS: 0.27
LWG</content/Training/>:   0% 649/1000000 [15:47<401:20:09,  1.45s/it]G: -0.26 | D: 1.94 | GP: 10.00 | SS: 0.29
LWG</content/Training/>:   0% 697/1000000 [16:54<391:38:29,  1.41s/it]G: -0.39 | D: 2.19 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 745/1000000 [18:05<397:02:57,  1.43s/it]G: -0.21 | D: 2.16 | GP: 10.00 | SS: 0.32
LWG</content/Training/>:   0% 793/1000000 [19:13<391:04:03,  1.41s/it]G: -0.03 | D: 1.65 | GP: 10.00 | SS: 0.36
LWG</content/Training/>:   0% 849/1000000 [20:36<396:24:19,  1.43s/it]G: -0.22 | D: 2.14 | GP: 10.00 | SS: 0.30
LWG</content/Training/>:   0% 897/1000000 [21:44<395:12:13,  1.42s/it]G: -0.29 | D: 2.20 | GP: 10.00 | SS: 0.28
LWG</content/Training/>:   0% 945/1000000 [22:54<393:23:01,  1.42s/it]G: -0.45 | D: 2.20 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 993/1000000 [24:02<391:36:45,  1.41s/it]G: -0.37 | D: 2.16 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 1049/1000000 [25:25<401:18:43,  1.45s/it]G: -0.37 | D: 2.17 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 1097/1000000 [26:33<394:19:55,  1.42s/it]G: -0.43 | D: 2.28 | GP: 10.00 | SS: 0.30
LWG</content/Training/>:   0% 1145/1000000 [27:45<399:24:02,  1.44s/it]G: -0.45 | D: 2.18 | GP: 10.00 | SS: 0.38
LWG</content/Training/>:   0% 1193/1000000 [28:52<390:31:48,  1.41s/it]G: -0.42 | D: 2.27 | GP: 10.00 | SS: 0.28
LWG</content/Training/>:   0% 1249/1000000 [30:15<397:12:58,  1.43s/it]G: -0.49 | D: 2.21 | GP: 10.00 | SS: 0.31
LWG</content/Training/>:   0% 1297/1000000 [31:24<395:35:03,  1.43s/it]G: -0.47 | D: 2.43 | GP: 10.00 | SS: 0.36
LWG</content/Training/>:   0% 1345/1000000 [32:35<397:08:04,  1.43s/it]G: -0.68 | D: 2.40 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 1393/1000000 [33:43<392:13:41,  1.41s/it]G: -0.73 | D: 2.44 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 1449/1000000 [35:06<400:23:47,  1.44s/it]G: -0.88 | D: 2.71 | GP: 10.00 | SS: 0.30
LWG</content/Training/>:   0% 1497/1000000 [36:13<391:56:53,  1.41s/it]G: -0.52 | D: 2.41 | GP: 10.00 | SS: 0.34
LWG</content/Training/>:   0% 1545/1000000 [37:26<403:18:44,  1.45s/it]G: -0.85 | D: 2.75 | GP: 10.00 | SS: 0.31
LWG</content/Training/>:   0% 1593/1000000 [38:33<391:32:21,  1.41s/it]G: -1.02 | D: 2.75 | GP: 10.00 | SS: 0.33
LWG</content/Training/>:   0% 1649/1000000 [39:56<394:39:26,  1.42s/it]G: -0.71 | D: 2.79 | GP: 10.00 | SS: 0.39
LWG</content/Training/>:   0% 1697/1000000 [41:05<394:23:49,  1.42s/it]G: -0.75 | D: 3.20 | GP: 10.00 | SS: 0.32
LWG</content/Training/>:   0% 1745/1000000 [42:15<395:36:40,  1.43s/it]G: -0.52 | D: 3.16 | GP: 10.00 | SS: 0.28
LWG</content/Training/>:   0% 1793/1000000 [43:24<394:27:18,  1.42s/it]G: -1.22 | D: 2.72 | GP: 10.00 | SS: 0.34
LWG</content/Training/>:   0% 1809/1000000 [43:50<415:27:46,  1.50s/it]NaN detected for generator or discriminator. Loading from checkpoint #1
loading from version 0.20.2
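One possible reading of the `GP: 10.00` plateau, offered as a hypothesis rather than a diagnosis: if GP is a WGAN-GP-style gradient penalty, `weight * mean((||dD(x)/dx||_2 - 1)^2)`, with a weight of 10 (the usual WGAN-GP default; treat this as an assumption and check the repository's source), then a logged value of exactly 10.00 is what you get when the discriminator's gradient with respect to its input is zero, or vanishingly small, for every sample in the batch. The pure-Python sketch below evaluates that formula on hand-made gradient vectors; `gp_from_grads` is a hypothetical helper, not a function from lightweight_gan.

```python
import math


def gp_from_grads(per_sample_grads, weight=10.0):
    """WGAN-GP-style penalty: weight * mean over the batch of (||g||_2 - 1)^2.

    per_sample_grads holds one flattened gradient dD(x)/dx per sample.
    weight=10.0 is an assumed value, not read from lightweight_gan.
    """
    penalties = []
    for g in per_sample_grads:
        norm = math.sqrt(sum(v * v for v in g))  # L2 norm of this sample's input gradient
        penalties.append((norm - 1.0) ** 2)
    return weight * sum(penalties) / len(penalties)


# Gradient norms near 1 give a small penalty (about 0.05 here).
print(gp_from_grads([[0.6, 0.8], [0.0, 1.1]]))

# All-zero gradients give weight * (0 - 1)^2 = 10.0 exactly, the same value as the plateau in the log.
print(gp_from_grads([[0.0, 0.0], [0.0, 0.0]]))
```

A zero input gradient can come from a saturated discriminator or, under mixed precision, from gradient values that underflow to zero in fp16; this run alone does not show which, if either, is happening here.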
KingOfCashews commented 3 years ago

This happens when you use the AMP flag. I saw the same thing in the StyleGAN2-PyTorch implementation when using its fp16 flag: the models seem to collapse quite quickly after initialization. Training works fine without AMP, albeit slower and more memory-intensive.
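As general background on why an AMP flag can change behavior this much (an illustration, not a diagnosis of this repository): fp16 has a far narrower range than fp32, so very small gradients round to zero and large ones overflow to `inf`, which then turns into NaN further downstream. PyTorch's `GradScaler` counters this with dynamic loss scaling: multiply the loss by a large factor before backward, skip the optimizer step and shrink the factor when any gradient is non-finite, and grow it again after a run of clean steps. The pure-Python sketch below emulates fp16 rounding with `struct`'s half-precision `'e'` format. `to_fp16` and `ToyLossScaler` are hypothetical names, and the growth interval of 3 steps is chosen for the demo (GradScaler's own default is 2000).

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back.

    struct's 'e' format raises OverflowError for values outside the fp16
    range; we map that to +/-inf, which is what fp16 arithmetic would yield.
    """
    try:
        return struct.unpack("e", struct.pack("e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class ToyLossScaler:
    """Minimal dynamic loss scaling in the spirit of PyTorch's GradScaler.

    Simplification: each "gradient" is multiplied by the scale and rounded to
    fp16 directly, instead of running a real fp16 backward pass.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_interval=3):
        self.scale = init_scale
        self.growth_interval = growth_interval  # illustrative; GradScaler defaults to 2000
        self._clean_steps = 0

    def step(self, true_grads):
        """Return unscaled fp16-rounded gradients, or None if the step is skipped."""
        used_scale = self.scale
        scaled = [to_fp16(g * used_scale) for g in true_grads]
        if not all(math.isfinite(g) for g in scaled):
            # Overflow: skip the optimizer step and back off the scale.
            self.scale = used_scale * 0.5
            self._clean_steps = 0
            return None
        # Unscale with the same factor that was applied, in full precision.
        unscaled = [g / used_scale for g in scaled]
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # A run of clean steps: try a larger scale to keep small gradients alive.
            self.scale = used_scale * 2.0
            self._clean_steps = 0
        return unscaled


tiny = 1e-8
print(to_fp16(tiny))              # 0.0: below fp16's smallest subnormal, so it underflows
print(to_fp16(tiny * 2.0 ** 16))  # nonzero once multiplied by 65536

scaler = ToyLossScaler()
print(scaler.step([tiny]))        # close to 1e-8 (fp16 rounding error only)
print(scaler.step([1.0]))         # 1.0 * 65536 exceeds fp16's max of 65504: None
print(scaler.scale)               # 32768.0 after backing off
```

In real PyTorch code the same pattern is `scaler.scale(loss).backward()`, `scaler.step(optimizer)`, `scaler.update()`. PyTorch's AMP examples also describe how to compute a gradient penalty under a `GradScaler` (scale the outputs passed to `torch.autograd.grad`, then unscale the result); whether the `--amp` path in lightweight_gan does this is not something this thread establishes.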

ckyleda commented 3 years ago

I get the same problem without AMP; it then always tries to load from checkpoint 0 (I think this may be a logging error). In general, the GAN seems to be highly unstable.
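The "NaN detected for generator or discriminator. Loading from checkpoint" line in the log above suggests a guard that restores the last saved state when the losses go non-finite. If the cause is systematic, restoring the same state and continuing with the same settings can simply diverge again, which would be consistent with the loop described in this issue. The sketch below is a hypothetical guard, not lightweight_gan's code: it counts consecutive rollbacks, halves a learning-rate multiplier on each one so the retry is not identical, and raises after a cap instead of looping forever. All names (`NanGuard`, `max_rollbacks`, `lr_scale`) are invented for illustration.

```python
import copy
import math


class NanGuard:
    """Hypothetical rollback guard for a training loop (illustration only)."""

    def __init__(self, max_rollbacks=3):
        self.max_rollbacks = max_rollbacks
        self.rollbacks = 0   # consecutive rollbacks since the last healthy save
        self.lr_scale = 1.0  # multiplier the caller would apply to its learning rates
        self._checkpoint = None

    def save(self, state):
        """Record a known-good state and reset the rollback counter."""
        self._checkpoint = copy.deepcopy(state)
        self.rollbacks = 0

    def check(self, state, losses):
        """Return the state to continue from; roll back if any loss is non-finite."""
        if all(math.isfinite(v) for v in losses):
            return state
        if self._checkpoint is None:
            raise RuntimeError("non-finite loss before any checkpoint was saved")
        self.rollbacks += 1
        if self.rollbacks > self.max_rollbacks:
            raise RuntimeError(f"still diverging after {self.max_rollbacks} rollbacks")
        self.lr_scale *= 0.5  # change something, so the retry is not an exact repeat
        return copy.deepcopy(self._checkpoint)


guard = NanGuard(max_rollbacks=2)
guard.save({"step": 1000, "weights": [0.1, 0.2]})

state = {"step": 1809, "weights": [float("nan"), 0.2]}
state = guard.check(state, losses=[float("nan"), 2.7])
print(state["step"], guard.lr_scale)  # 1000 0.5
```

Halving the learning rate is only one possible change between retries; the point of the sketch is that a rollback which changes nothing, and is never capped, can repeat indefinitely.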