IDEALLab / bezier-gan

Bézier Generative Adversarial Networks
MIT License

Assertion error: nan Fake data trained on Bezier GAN #5

Open veenan2791 opened 2 years ago

veenan2791 commented 2 years ago

I am using the Bezier GAN model on my own dataset of different closed shapes. However, training stops at a random step: the generator produces NaN fake data, which triggers an assertion error. The error is raised by the following lines in the beziergan/gan.py file.

if np.any(np.isnan(X_fake)):
    ind = np.any(np.isnan(X_fake), axis=(1,2,3))
    print(self.sess.run(ub, feed_dict={self.c: y_latent, self.z: noise})[ind])
    assert not np.any(np.isnan(X_fake))

Here is the kind of error I get.

1293: [D] real 0.002927 fake 0.246667 q 0.676213:  [G] fake 5.136545 reg 0.032323 q 0.660840
1294: [D] real 0.050245 fake 0.022135 q 0.675265:  [G] fake 4.619563 reg 0.033764 q 0.663499
1295: [D] real 0.168092 fake 1.496049 q 0.672238:  [G] fake 10.919476 reg 0.033241 q 0.633784
1296: [D] real 3.172950 fake 0.248389 q 0.629469:  [G] fake 4.550991 reg 0.035288 q 0.622464
1297: [D] real 0.820614 fake 4.543212 q 0.652395:  [G] fake 14.652376 reg 0.053193 q 0.677055
1298: [D] real 3.526137 fake 14.950486 q 74.995827:  [G] fake 25.814026 reg 0.048530 q 1192.471313
1299: [D] real 1.165939 fake 75.212219 q 3594630201344.000000:  [G] fake nan reg 0.052115 q nan
[[[ 0.]
  [nan]
  [nan]
  ...
  [nan]
  [nan]
  [nan]]

 [[ 0.]
  [nan]
  [nan]
  ...
  [nan]
  [nan]
  [nan]]

 [[ 0.]
  [nan]
  [nan]
  ...
  [nan]
  [nan]
  [nan]]

AssertionError                            Traceback (most recent call last)
<ipython-input-14-98cd6b708c54> in <module>()
      1 model = GAN(latent_dim, noise_dim, X_train.shape[1], bspline_degree, bounds)
----> 2 mod1 = model.train(X_train, batch_size=batch_size, train_steps=5000, directory=directory)

<ipython-input-8-fc4ff70068f4> in train(self, X_train, train_steps, batch_size, save_interval, directory)
    374                 print(self.sess.run(ub, feed_dict={self.c: y_latent, self.z: noise})[ind])
    375                 # try:
--> 376                 assert not np.any(np.isnan(X_fake))
    377 
    378                 # the errror_message provided by the user gets printed

AssertionError: 
 ...

Does anyone know the reason behind it and how to resolve the issue? It would be a great help for me.
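The log above shows the discriminator's q loss jumping by orders of magnitude in the steps just before the NaN appears, so a guard on the losses could flag the run earlier than the assertion does. The following is only a minimal sketch under stated assumptions: `first_divergence` and its `max_ratio` threshold are hypothetical names and values chosen for illustration, not part of beziergan/gan.py.

```python
import math

def first_divergence(losses, max_ratio=50.0):
    """Return the index of the first loss that is non-finite or more than
    max_ratio times the previous loss in magnitude; None if neither occurs.

    max_ratio is an arbitrary illustrative threshold (an assumption).
    """
    for i, value in enumerate(losses):
        if not math.isfinite(value):
            return i
        if i > 0:
            previous = abs(losses[i - 1])
            if previous > 0 and abs(value) > max_ratio * previous:
                return i
    return None

# [D] q losses copied from steps 1293-1299 of the log above.
q_losses = [0.676213, 0.675265, 0.672238, 0.629469, 0.652395,
            74.995827, 3594630201344.0]

# Index 5 corresponds to step 1298, one step before the NaN shows up.
print(first_divergence(q_losses))  # -> 5
```

Calling such a check inside the training loop would allow saving a checkpoint or stopping cleanly once a loss starts to blow up.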

wchen459 commented 2 years ago

There are two things you might need to check:

  1. Does the code always end up with NaN losses? Sometimes this happens because of stochasticity in the random initialization, batch sampling, etc., and it will not recur when you rerun the code.
  2. Are the model architecture and training hyperparameters properly configured for your new data? This is more likely in your case, because I can see high generator losses before the NaN appears, which means your generator is probably not powerful enough to compete with the discriminator. You may need to check, e.g., whether the data contain values that the generator cannot generate, whether the data variability is too high or the dataset is too small, or whether the learning rate or model complexity needs to be tuned.
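The check for "values that the generator cannot generate" can be sketched as a quick scan of the training array. This is an illustration only: `check_training_data` is a hypothetical helper, and the `lower` and `upper` limits are placeholders for whatever range the generator's output layer can actually produce, which this thread does not state.

```python
import numpy as np

def check_training_data(X, lower=-1.0, upper=1.0):
    """Count non-finite entries and entries outside [lower, upper].

    lower/upper are assumed placeholders for the generator's output range.
    """
    X = np.asarray(X, dtype=float)
    finite = np.isfinite(X)
    # Replace non-finite entries before comparing so they are counted once,
    # as non-finite, and not again as out of range.
    safe = np.where(finite, X, lower)
    out_of_range = (safe < lower) | (safe > upper)
    return {
        "n_samples": int(X.shape[0]),
        "n_nonfinite": int(np.sum(~finite)),
        "n_out_of_range": int(np.sum(out_of_range)),
        "min": float(X[finite].min()) if finite.any() else None,
        "max": float(X[finite].max()) if finite.any() else None,
    }

# Toy stand-in for X_train: 4 shapes, 3 points each, (x, y) coordinates.
X_demo = np.zeros((4, 3, 2))
X_demo[0, 0, 0] = np.nan   # a corrupted coordinate
X_demo[1, 1, 1] = 2.0      # a coordinate outside the assumed range
print(check_training_data(X_demo))
```

Non-zero counts would suggest cleaning or rescaling the data before tuning the learning rate or model size.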

veenan2791 commented 2 years ago

Yes, it could be the random initialization. I reran the code several times and it eventually worked: it completes without the error about once in every five or more runs. I will also look at tuning the hyperparameters and the model architecture. Thank you very much for your suggestions.
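Since only some runs finish, one way to work around the stochastic failures is to retry with different seeds and record which seed succeeded. The sketch below assumes a hypothetical `train_fn(seed)` wrapper that builds a fresh model, seeds the framework, and calls `model.train(...)`; neither `train_with_retries` nor `train_fn` exists in the repository.

```python
import numpy as np

def train_with_retries(train_fn, seeds):
    """Call train_fn(seed) for each seed until a run finishes without the
    NaN assertion firing. Returns (seed, result); raises if every seed fails."""
    failed = []
    for seed in seeds:
        # Seeds NumPy only; seeding the deep learning framework itself is
        # assumed to happen inside train_fn.
        np.random.seed(seed)
        try:
            return seed, train_fn(seed)
        except AssertionError:
            failed.append(seed)
    raise RuntimeError("all seeds hit the NaN assertion: %s" % failed)

# Stand-in for a real training run: succeeds only for seeds divisible by 5,
# loosely mimicking "works about once in five runs".
def fake_train(seed):
    assert seed % 5 == 0
    return "trained"

seed, result = train_with_retries(fake_train, range(1, 11))
print(seed, result)
```

Retrying hides the instability without fixing it, so the data and hyperparameter checks suggested earlier in the thread still apply.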