tdeboissiere / DeepLearningImplementations

Implementation of recent Deep Learning papers

WGAN-GP training MNIST model diverging? #31

Closed: yif0 closed this issue 7 years ago

yif0 commented 7 years ago

Hi, and thanks for these great implementations. I modified the WGAN-GP code a little bit (basically just taking the code from GAN_tf for MNIST) and tried to train on the MNIST dataset instead of celebA. I got the following error, which probably indicates training instability? Have you encountered similar problems? Thanks!

InvalidArgumentError (see above for traceback): Nan in summary histogram for: discriminator/conv2D_1/bn/beta_0/gradient
    [[Node: discriminator/conv2D_1/bn/beta_0/gradient = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](discriminator/conv2D_1/bn/beta_0/gradient/tag, gradients_2/AddN_19/_113)]]
    [[Node: gradients_2/gradients/discriminator_2/conv2D_1/bn/moments/sufficient_statistics/var_ss_grad/Tile_grad/transpose/_106 = _Send[T=DT_INT32, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_783_gradients_2/gradients/discriminator_2/conv2D_1/bn/moments/sufficient_statistics/var_ss_grad/Tile_grad/transpose", _device="/job:localhost/replica:0/task:0/gpu:0"]]

Epoch 5: 50%|████▉ | 199/400 [01:10<01:11, 2.81it/s]

tdeboissiere commented 7 years ago

I had this issue as well. It is caused by the new gradient penalty component in the loss, but I could not really pin down exactly why. Lowering the learning rate or removing the batch normalization layers in the discriminator may help.
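For reference, here is a minimal sketch (TensorFlow 1.x style, matching the era of this repo) of what that extra loss component, the gradient penalty, typically looks like, together with the two mitigations mentioned above. The names `discriminator`, `X_real` and `X_fake` are placeholders for whatever the actual code uses, not the identifiers in this repository.

```python
import tensorflow as tf

def gradient_penalty(discriminator, X_real, X_fake, batch_size, lambda_gp=10.0):
    """Sketch of the WGAN-GP penalty term (assumes NHWC image tensors)."""
    # Random interpolation between real and generated samples
    eps = tf.random_uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    X_interp = eps * X_real + (1.0 - eps) * X_fake

    # Gradient of the critic output w.r.t. the interpolated input
    D_interp = discriminator(X_interp, reuse=True)
    grads = tf.gradients(D_interp, [X_interp])[0]

    # Small epsilon inside the sqrt avoids a NaN gradient when the norm is 0,
    # which is one common source of NaNs in the summary histograms
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)

    # Penalize deviation of the per-sample gradient norm from 1
    return lambda_gp * tf.reduce_mean(tf.square(grad_norm - 1.0))

# Mitigations suggested above (both are assumptions about your setup):
# 1) a smaller learning rate for the critic/generator optimizers, e.g.
#    d_opt = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)
# 2) no batch norm in the critic; tf.contrib.layers.layer_norm can be used
#    in place of tf.contrib.layers.batch_norm inside the discriminator.
```

The WGAN-GP paper itself recommends avoiding batch normalization in the critic (it suggests layer normalization instead), since the penalty is computed on per-sample gradients while batch norm couples the samples within a batch.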