kunrenzhilu opened this issue 6 years ago
This is due to a numerical issue with \log. Adding a tiny number, e.g. 1e-5, to the output of enc_std and dec_std will help.
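For instance, a minimal sketch of that fix, assuming the std outputs come from Softplus heads as in typical VRNN implementations (the layer names and sizes here are only illustrative, not the repo's actual code):

    import torch
    import torch.nn as nn

    EPS = 1e-5  # floor added to the std outputs so they can never collapse to zero

    # hypothetical std heads; the actual layers in the repo may differ
    enc_std_layer = nn.Sequential(nn.Linear(128, 16), nn.Softplus())
    dec_std_layer = nn.Sequential(nn.Linear(128, 16), nn.Softplus())

    h = torch.randn(4, 128)
    enc_std = enc_std_layer(h) + EPS  # Softplus is positive but can underflow toward 0
    dec_std = dec_std_layer(h) + EPS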
This may also help:

    eps = torch.finfo(torch.float32).eps  # float32 machine epsilon (~1.2e-7), keeps log() arguments away from zero

    def _nll_bernoulli(self, theta, x):
        # Bernoulli negative log-likelihood; the added eps prevents log(0) = -inf
        return -torch.sum(x * torch.log(theta + eps) + (1 - x) * torch.log(1 - theta + eps))
Plus, set a minimum stdev as Mr Song mentioned above.
Also, this model assumes all inputs are between 0.0 and 1.0, so normalize your data using (x - x.min) / (x.max - x.min) before passing it in.
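For example, a quick sketch of that min-max scaling (the tensor shape and values are just illustrative):

    import torch

    x = torch.rand(64, 28 * 28) * 255.0             # e.g. a batch of raw pixel intensities
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)  # rescale to [0, 1]; the small constant guards against a constant batch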
Where is the normalisation assumption made and how does it affect the model if a different normalisation is used?
Sorry, it's been a couple of years since I looked at this model. Generally speaking, though, NN weights are initialized with Gaussian noise roughly in the range -1.0 to 1.0, so it's usually a good idea to feed in numbers in a similar range; otherwise you can get numerical instability, since gradients can exceed 1.0 and cause runaway updates. This is especially problematic when modeling variance, as the gradients on the variance have more variance themselves than the gradients on the mean during training.

Also, floating-point representation loses precision the further you get from the range -1.0 to 1.0. I believe about 75% of the representable floating-point numbers lie between -1.0 and 1.0, with only 25% in the rest of the range.

The biggest problem you are going to have with modelling variance is the training process driving the variance very small, so the likelihood becomes large and positive and goes out of range of your representation. That's why adding a minimum epsilon to the variance is a good idea: it ensures the variance cannot become absurdly small and produce massive (and unrealistic) likelihoods.
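As a small illustration of that failure mode (a sketch, not code from this repo): with a tiny standard deviation the Gaussian negative log-likelihood comes out as nan in float32, while flooring the std with an epsilon keeps it finite.

    import math
    import torch

    def gaussian_nll(x, mean, std):
        # elementwise negative log-likelihood of a diagonal Gaussian
        return 0.5 * (torch.log(2 * math.pi * std ** 2) + (x - mean) ** 2 / std ** 2)

    x = torch.tensor([0.50])
    mean = torch.tensor([0.49])

    print(gaussian_nll(x, mean, torch.tensor([1e-30])))         # std**2 underflows to 0 in float32 -> nan
    print(gaussian_nll(x, mean, torch.tensor([1e-30]) + 1e-5))  # epsilon floor keeps the result finite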
    Train Epoch: 3 [0/60000 (0%)]       KLD Loss: 2.687659  NLL Loss: 73.599564
    Train Epoch: 3 [2800/60000 (21%)]   KLD Loss: 2.976363  NLL Loss: 78.757454
    Train Epoch: 3 [5600/60000 (43%)]   KLD Loss: 2.837864  NLL Loss: 78.958122
    Train Epoch: 3 [8400/60000 (64%)]   KLD Loss: nan       NLL Loss: nan
    Train Epoch: 3 [11200/60000 (85%)]  KLD Loss: nan       NLL Loss: nan
    ====> Epoch: 3 Average loss: nan
    ====> Test set loss: KLD Loss = nan, NLL Loss = nan
    Train Epoch: 4 [0/60000 (0%)]       KLD Loss: nan       NLL Loss: nan