larsmaaloee / auxiliary-deep-generative-models

Deep generative models for semi-supervised learning.
MIT License
109 stars 28 forks

nan encountered in run_mnist.py #1

Closed poolio closed 8 years ago

poolio commented 8 years ago

Thanks for putting this code up! Running the run_mnist.py script without any modifications, I get NaNs at epoch 133. Here's the relevant output:

test 100-samples: 2.66%.
epoch=0131; time=115.93; lb=196.7493; lb-labeled=54.7495; lb-unlabeled=89.8326; test=2.53%; validation=2.52%;
epoch=0132; time=115.77; lb=196.7117; lb-labeled=54.7303; lb-unlabeled=89.8130; test=2.48%; validation=2.51%;
epoch=0133; time=115.43; lb=nan; lb-labeled=nan; lb-unlabeled=nan; test=90.20%; validation=90.09%;
epoch=0134; time=115.42; lb=nan; lb-labeled=nan; lb-unlabeled=nan; test=90.20%; validation=90.09%;

Is this a known issue? Do I need to tweak the learning rates? Thanks!

larsmaaloee commented 8 years ago

Thank you for your interest in the code. I believe this is due to the batch normalisation not being implemented correctly at this stage (it is not covered in the article either). Have you tried running without batch normalisation? The batch statistics need to be weighted correctly in the unlabelled case; fixing this is on my todo list.
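As a rough illustration of the weighting point (a minimal sketch, not the repository's actual implementation): in the unlabelled bound each input is typically replicated once per class label, so plain batch statistics over the replicated batch count every replica equally, whereas one would arguably weight each replica, e.g. by q(y|x). The helper name `weighted_batch_norm` and the NumPy setup below are hypothetical; the `eps` term guards the division by the standard deviation, one plausible source of the NaNs.

```python
import numpy as np

def weighted_batch_norm(h, w, eps=1e-6):
    """Normalise activations h (batch, features) with per-example
    weights w (batch,), e.g. q(y|x) for class-replicated unlabelled
    inputs. Plain batch norm is the special case of uniform weights.
    eps keeps the denominator away from zero."""
    w = w / w.sum()                                   # normalise weights
    mean = (w[:, None] * h).sum(axis=0)               # weighted mean
    var = (w[:, None] * (h - mean) ** 2).sum(axis=0)  # weighted variance
    return (h - mean) / np.sqrt(var + eps)
```

Under uniform weights this reduces to ordinary (training-time) batch normalisation without the learned scale and shift, so the sketch only isolates the statistics-weighting question raised above.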

poolio commented 8 years ago

Convergence is much slower without batchnorm, but disabling it fixed the NaNs. Thanks!