eriklindernoren / Keras-GAN

Keras implementations of Generative Adversarial Networks.
MIT License

nan during training in CycleGAN #87

Open yhgon opened 5 years ago

yhgon commented 5 years ago

Has anyone else seen NaN losses during training?

Configuration:
TensorFlow: 1.11.0
Keras: 2.1.6
GPU: K80
Model: CycleGAN
Dataset: apple2orange

I've tested multiple models: aggan, WGAN, DCGAN and pix2pix all train well for 1,000 epochs in my environment. CycleGAN also trains well at first, but the losses become NaN in epoch 67.

I modified the code to use BatchNormalization (batch size 32) instead of InstanceNormalization.

gan.train(epochs=200, batch_size=64, sample_interval=15)

Here is the output (the first batches of epoch 0, then the point in epoch 67 where the losses turn NaN):

/usr/local/lib/python3.6/dist-packages/scipy/misc/pilutil.py:482: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.
  if issubdtype(ts, int):
/usr/local/lib/python3.6/dist-packages/scipy/misc/pilutil.py:485: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  elif issubdtype(type(size), float):
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py:975: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
  'Discrepancy between trainable weights and collected trainable'
[Epoch 0/200] [Batch 0/15] [D loss: 10.340113, acc:  16%] [G loss: 36.504208, adv: 9.286882, recon: 0.813415, id: 1.059065] time: 0:00:46.796445 
[Epoch 0/200] [Batch 1/15] [D loss: 10.501231, acc:  14%] [G loss: 26.232471, adv: 4.194573, recon: 0.792024, id: 1.146999] time: 0:00:51.253706 
[Epoch 0/200] [Batch 2/15] [D loss: 7.842362, acc:  17%] [G loss: 22.555479, adv: 3.907498, recon: 0.631787, id: 1.116496] time: 0:00:55.256927 
[Epoch 0/200] [Batch 3/15] [D loss: 3.473462, acc:  21%] [G loss: 17.249832, adv: 2.483031, recon: 0.506841, id: 1.100872] time: 0:00:59.269601
[Epoch 67/200] [Batch 7/15] [D loss: 0.071113, acc:  95%] [G loss: 1.885691, adv: 0.042987, recon: 0.077898, id: 0.101703] time: 1:05:19.633317 
[Epoch 67/200] [Batch 8/15] [D loss: 0.096040, acc:  91%] [G loss: 1.872224, adv: 0.072618, recon: 0.075492, id: 0.099593] time: 1:05:23.721828 
[Epoch 67/200] [Batch 9/15] [D loss: 0.151695, acc:  80%] [G loss: 2.235736, adv: 0.179553, recon: 0.082687, id: 0.102930] time: 1:05:27.831364 
[Epoch 67/200] [Batch 10/15] [D loss: 0.140757, acc:  83%] [G loss: 1.851864, adv: 0.066634, recon: 0.074924, id: 0.101068] time: 1:05:31.932562 
[Epoch 67/200] [Batch 11/15] [D loss: 0.093928, acc:  89%] [G loss: 2.103337, adv: 0.081708, recon: 0.085628, id: 0.117779] time: 1:05:36.041804 
[Epoch 67/200] [Batch 12/15] [D loss: nan, acc:  60%] [G loss: 00nan, adv: 00nan, recon: 0.075524, id: 0.119893] time: 1:05:40.103866 
[Epoch 67/200] [Batch 13/15] [D loss: nan, acc:  25%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 0.092526] time: 1:05:44.104335 
[Epoch 68/200] [Batch 0/15] [D loss: nan, acc:  12%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 00nan] time: 1:05:48.023691 
/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:29: RuntimeWarning: invalid value encountered in reduce
  return umr_minimum(a, axis, None, out, keepdims)
[Epoch 68/200] [Batch 1/15] [D loss: nan, acc:   0%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 00nan] time: 1:05:52.220977 
[Epoch 68/200] [Batch 2/15] [D loss: nan, acc:   0%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 00nan] time: 1:05:56.135995 
[Epoch 68/200] [Batch 3/15] [D loss: nan, acc:   0%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 00nan] time: 1:06:00.045664 
[Epoch 68/200] [Batch 4/15] [D loss: nan, acc:   0%] [G loss: 00nan, adv: 00nan, recon: 00nan, id: 00nan] time: 1:06:03.959619
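In the log above, the D loss and the adversarial loss become NaN together at epoch 67, batch 12, and the reconstruction and identity terms follow within two batches, after which training continues uselessly. Keras's `TerminateOnNaN` callback only applies to `fit()`, so a loop built on `train_on_batch` needs its own check. A minimal sketch, assuming you pass it the same scalar loss values the loop already prints; `assert_finite` is a hypothetical helper, not part of Keras-GAN:

```python
import math

def assert_finite(step, **losses):
    """Hypothetical helper: raise as soon as any loss value is NaN or infinite."""
    bad = {name: value for name, value in losses.items()
           if not math.isfinite(float(value))}
    if bad:
        raise FloatingPointError(f"non-finite loss at {step}: {bad}")

# Values taken from the log above: the first call passes silently.
assert_finite("epoch 67, batch 11", d_loss=0.093928, g_loss=2.103337)
try:
    assert_finite("epoch 67, batch 12", d_loss=float("nan"),
                  g_loss=float("nan"), recon=0.075524)
except FloatingPointError as err:
    print(err)  # reports only the non-finite entries
```

Raising immediately leaves the last finite weights and sample images intact, which makes it easier to inspect what happened just before the blow-up.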
yhgon commented 5 years ago


Below are the changes I made to the generator:

            d = BatchNormalization()(d) #d = InstanceNormalization()(d)  
            u = BatchNormalization()(u) # u = InstanceNormalization()(u)

And to the discriminator:

            if normalization:
                d = BatchNormalization()(d)  # d = InstanceNormalization()(d)
            return d
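For comparison, here is a simplified NumPy sketch of which axes each layer computes its statistics over for channels-last batches of shape (N, H, W, C). It is an illustration only: it omits the learned scale and offset and the moving averages Keras uses at inference time, and it assumes an epsilon of 1e-5; it is not the Keras or keras_contrib implementation.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # One mean/variance per channel, pooled over batch and spatial axes (N, H, W).
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    # One mean/variance per image and per channel, over spatial axes (H, W) only.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 8, 8, 4))

# Instance norm: every individual image/channel slice has ~zero mean.
print(np.abs(instance_norm(x).mean(axis=(1, 2))).max())
# Batch norm: only the per-channel mean over the whole batch is ~zero.
print(np.abs(batch_norm(x).mean(axis=(0, 1, 2))).max())
# With a single image, both compute identical statistics.
print(np.allclose(batch_norm(x[:1]), instance_norm(x[:1])))  # True
```

With a batch of one image the training-mode statistics are the same, but with batch sizes of 32 or 64 each image is normalized using statistics pooled across the whole batch, so the swap changes the model's behavior rather than being a drop-in replacement.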
Feiyu-Zhang commented 5 years ago


Have you found the cause?