y0ast / VAE-Torch

Implementation of Variational Auto-Encoder in Torch7
MIT License

Use of exp in the Reparametrization #3

Closed. zencoding closed this issue 9 years ago.

zencoding commented 9 years ago

Hi, I am trying to follow D. Kingma and M. Welling's paper at http://arxiv.org/pdf/1312.6114v10.pdf alongside your Torch code. I can follow everything in the code except the use of exp in the Reparametrization module's updateOutput and updateGradInput functions. I could not find any mention of an exponential in the paper or in other implementations of VAE (in Theano/PyLearn2). Also, is the fill value of 0.5 just an arbitrary value, or is there a specific reason for it?

    if torch.typename(input[1]) == 'torch.CudaTensor' then
        self.eps = self.eps:cuda()
        self.output = torch.CudaTensor():resizeAs(input[2]):fill(0.5)
    else
        self.output = torch.Tensor():resizeAs(input[2]):fill(0.5)
    end

BTW, I have reviewed five different implementations of the Variational Auto-Encoder, and your Torch implementation is the most concise and clear.

Thank you

AjayTalati commented 9 years ago

In Joost's code,

input[2] = log(sigma^2)

but I think it might be easier, and more stable, to re-implement it with this change of variables:

input[1] = mu

input[2] = log (sigma)

input[3] = eps -- the noise samples, eps ~ N(0, I), which has the same distribution as the prior p(z)

then

sigma = exp(input[2])

sigma^2 = exp(2 * input[2])

and

log(sigma^2) = (2 * input[2])

So the output is

z = mu + sigma * eps = mu + exp(log(sigma)) * eps

and the gradients are

dz/d(mu) = 1

dz/d(log(sigma)) = exp(log(sigma)) * eps = exp(input[2]) * input[3]

dz/d(eps) = exp(log(sigma)) = exp(input[2])
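A minimal sketch of the log(sigma) parameterization described above, written in plain Python rather than Torch so it can run anywhere. The function names (`reparam_forward`, `reparam_backward`, `finite_diff_check`) are hypothetical and not part of VAE-Torch; the point is only to show the forward pass z = mu + exp(log(sigma)) * eps, the two gradients that matter, and a central finite-difference check of dz/d(log(sigma)).

```python
import math
import random

def reparam_forward(mu, log_sigma, eps):
    """z = mu + exp(log_sigma) * eps, element-wise over equal-length lists."""
    return [m + math.exp(ls) * e for m, ls, e in zip(mu, log_sigma, eps)]

def reparam_backward(log_sigma, eps, grad_z):
    """Given dL/dz, return (dL/dmu, dL/dlog_sigma) via the chain rule.

    dz/dmu = 1 and dz/dlog_sigma = exp(log_sigma) * eps.
    eps is sampled noise, so no gradient is computed for it.
    """
    grad_mu = list(grad_z)
    grad_log_sigma = [g * math.exp(ls) * e for g, ls, e in zip(grad_z, log_sigma, eps)]
    return grad_mu, grad_log_sigma

def finite_diff_check(mu, log_sigma, eps, h=1e-6):
    """Max gap between the analytic dz/dlog_sigma and a central difference,
    using the scalar objective L = sum(z), so dL/dz = 1 everywhere."""
    _, analytic = reparam_backward(log_sigma, eps, [1.0] * len(mu))
    numeric = []
    for i in range(len(log_sigma)):
        up = list(log_sigma)
        up[i] += h
        down = list(log_sigma)
        down[i] -= h
        diff = sum(reparam_forward(mu, up, eps)) - sum(reparam_forward(mu, down, eps))
        numeric.append(diff / (2 * h))
    return max(abs(a - n) for a, n in zip(analytic, numeric))

random.seed(0)
mu = [0.3, -1.2, 0.5]
log_sigma = [-0.5, 0.1, 0.0]
eps = [random.gauss(0.0, 1.0) for _ in mu]
print(reparam_forward(mu, log_sigma, eps))
print(finite_diff_check(mu, log_sigma, eps))  # should be tiny
```

Since eps is sampled noise rather than a learned quantity, dz/d(eps) is never needed during training; only the mu and log(sigma) gradients flow back into the encoder.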

Hope that helps! If you need more help, send me an email.

PS: I found the VAE in Theano/PyLearn2 to be a waste of time; if you implement it in Torch, it's much clearer and quicker.

y0ast commented 9 years ago

Ajay is right, it is cleaner (and likely more numerically stable) to do it his way. I switched to doing it that way too in later projects.

One point on your math though: dz/d(log(sigma)) = exp(log(sigma)) * eps, which in Torch is the element-wise product torch.exp(input[2]):cmul(eps).

I just tried doing this and pushed some code in this branch: https://github.com/y0ast/VAE-Torch/tree/sigma_squared, specifically this commit: https://github.com/y0ast/VAE-Torch/commit/6f5690e6150729c344bea4827bda1e4a2063f27f

Currently it's giving NaNs, because the KLD blows up (it's about 100x higher from the start). I am not sure how to fix this yet, I'll look at it again later.

AjayTalati commented 9 years ago

Hi Joost :)

Sorry, a typo; just fixed it.

I'm implementing an LSTM VAE, so I haven't had time to do full VB for your VAE, since the LSTM VAE does full VB.

When I clean up the code I'll send it to you - it's a mess at the moment, but you'd like it :+1:

Also, there are a couple of typos in your paper with Otto:

In the equation for the estimator L(theta; X^{i}), just above "2.2 Model", log((sigma^{2})) is missing a j subscript.

In the equations for mu and log sigma, "end" should be "enc".

AjayTalati commented 9 years ago

Joost,

try initializing your parameters with uniform weights in the range (-0.08, 0.08), i.e.

    parameters, gradients = va:getParameters()
    init_weight = 0.08
    parameters:uniform(-init_weight, init_weight)

Or maybe try standard Gaussians, i.e. params ~ N(0,1)

Also, clamp the gradient output of your objective function opfunc to the range (-1, 1) or (-10, 10), i.e.

    -- clip gradients element-wise
    gradients:clamp(-1, 1)

Another trick: try adding a guard against overflowing exponentials, like

exp(x) -> exp(min(x, 99999))

I'm experimenting with this in my nn.LatentLoss criterion module

I've found this to help sometimes
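A small Python sketch of the two safeguards above: element-wise gradient clipping and a capped exponential. One caveat, assuming IEEE floating point: exp overflows once its argument exceeds about 709.78 in double precision (about 88.72 in single precision, which is what torch.CudaTensor uses), so a cap of 99999 does not by itself prevent an overflow. The cap of 80 below is an arbitrary illustrative value, and all function names are hypothetical. Note that Python's math.exp raises OverflowError, whereas a Torch tensor would silently hold inf.

```python
import math

def clip_elementwise(grads, lo=-1.0, hi=1.0):
    """Clamp every gradient entry into [lo, hi], like gradients:clamp(-1, 1) in Torch."""
    return [min(max(g, lo), hi) for g in grads]

def capped_exp(x, cap=80.0):
    """exp(min(x, cap)); the cap must stay below ~709.78 (float64) or ~88.72 (float32)."""
    return math.exp(min(x, cap))

def exp_overflows(x):
    """True if math.exp(x) overflows a Python float (an IEEE double)."""
    try:
        math.exp(x)
    except OverflowError:
        return True
    return False

print(clip_elementwise([-3.0, 0.25, 12.0]))  # [-1.0, 0.25, 1.0]
print(exp_overflows(99999.0))                # True: a cap of 99999 is too large
print(exp_overflows(700.0))                  # False
print(math.isfinite(capped_exp(99999.0)))    # True
```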

zencoding commented 9 years ago

Thanks, I understand the exp in the Reparameterization.

Most deep learning implementations work perfectly fine on MNIST or CIFAR, but when I run them on real natural images, their performance is very bad. Can you comment on whether this implementation (or the VariationalDeconvnet) can work for larger images (224x224), and if so, what the challenges are?

dpkingma commented 9 years ago

Hi all,

VAEs work fine on natural images, but some care needs to be taken:

Cheers, Durk


AjayTalati commented 9 years ago

Hi Durk :)

How are you? Thanks for the great advice you gave me a while back. I made a lot of progress once I stopped trying to understand other people's code and just implemented it myself from scratch, as you advised!

_That's exactly what's happening!!!!!_

> One source of instability in regular (unnormalized) VAEs is that small changes in each of the parameters of the encoder can together lead to large effects on the encoder outputs; e.g. one gradient step can lead to a jump of several standard deviations in the encoder output. This problem usually becomes worse later in training, but is luckily quite easily prevented by weight normalization.

To be precise, I'm using LSTMs for the encoder and decoder. I'm finding that the outputs of the encoder, in particular the norm of the cell states (yeah, I know it's not really an output), say |c_enc_t|, blow up to infinity towards the end of the forward pass. This seems to be exactly what you describe above, as it happens from the middle of training (after about 20 epochs) and never recovers.

I thought that this might be caused by:

a) backpropagation through time, where after each SGD step you copy the LSTM's final states (after the backward pass) to the initial states for the next evaluation (forward pass) of the objective function. So I got rid of that.

b) simply having too many LSTM cells cloned for the system (not sure this is the right terminology; intuitively there's one clone of the system (LSTM_enc, Q_sampler, LSTM_dec) for each glimpse of the attention mechanism). I was using 64 clones and reduced it to 16, but it's still blowing up.

I'll try batch normalization, or L2 normalization plus a rescaling parameter, as you suggest; it might take a while. Just to be sure, is it Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (http://arxiv.org/pdf/1502.03167.pdf) that you think I should try?

Thanks for the great advice, and for your creations, the VAE and Adam.

Best,

Aj

dpkingma commented 9 years ago

Hi Ajay,

Yes, that's the paper.

Regards, Durk


zencoding commented 9 years ago

I tried the updated code and also got NaN from the very first epoch. I also tried Ajay's suggestions, but it never got out of NaN. When I changed the learning rate to -0.00001, it came out of NaN, but the results were weird: it started at a lower bound of about -520, and by around epoch 200 the lower bound had dropped to extremely low values.

    Epoch: 1 Lowerbound: -521.71322166922 time: 0.80144906044006
    Epoch: 2 Lowerbound: -517.45266375482 time: 0.79507112503052
    Epoch: 3 Lowerbound: -514.94652135643 time: 0.78467798233032
    Epoch: 4 Lowerbound: -512.93061513467 time: 0.78672099113464
    ...
    Epoch: 201 Lowerbound: -1.4880559243132e+25 time: 0.78963804244995
    Epoch: 202 Lowerbound: -4.303366228673e+40 time: 0.77805995941162
    Epoch: 203 Lowerbound: -8.7918260891567e+28 time: 0.77181601524353
    Epoch: 204 Lowerbound: -334050880319.71 time: 0.76645302772522
    Epoch: 205 Lowerbound: -4.3466713285911e+26 time: 0.7693018913269
    Epoch: 206 Lowerbound: -1.1311347150923e+43 time: 0.76804494857788
    Epoch: 207 Lowerbound: -8213027234510.8 time: 0.76709008216858

AjayTalati commented 9 years ago

Yep, Joost's version uses gradient ascent, not descent, so you need to reverse the sign of the learning rate.

I'd check line 16 of the KLD criterion; I think his gradient of the KLD w.r.t. log_sigma_t differs from mine.

You might need to multiply input[2] by 2. My implementation uses

    grad_KLD_matrix_wrt_log_sigma_t = torch.exp(torch.mul(input[2], 2)):add(-1)

It's not difficult to check the maths; just use the chain rule. Maybe he's right and I'm wrong? Best to check all the calculations yourself.

My experience is that a feed-forward VAE on MNIST trained with Adam should reach a sensible negative log-likelihood after 10 epochs; if not, there's a problem. Instabilities seem to kick in after about 20 epochs, so you really have to run it for 100 epochs to test your code. Also, for a fast test run you don't need large hidden layers or many latent variables, and using half the dataset works fine too.

y0ast commented 9 years ago

Hmm, I just fixed the gradient of the KLD. I think you flipped the sign, Ajay, and zencoding is using the right sign for the LR (negative).

$$\mathrm{KLD} = \frac{1}{2}\sum_i \left[1 + 2\log\sigma_i - \mu_i^2 - \sigma_i^2\right]$$

$$\frac{\partial\,\mathrm{KLD}}{\partial \log\sigma_i} = \frac{1}{2}\left[2 - 2\exp(2\log\sigma_i)\right] = 1 - \exp(2\log\sigma_i)$$

In Torch:

    self.gradInput[2] = (-torch.exp(torch.mul(input[2], 2))):add(1)
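To check the corrected gradient numerically, here is a plain-Python sketch with hypothetical names (not the repository's code). As in the formula above, "KLD" means the term 0.5 * sum(1 + 2 log(sigma) - mu^2 - sigma^2) that is added to the lower bound, i.e. minus the KL divergence, so it is maximized and the update shown is a gradient ascent step.

```python
import math

def neg_kld(mu, log_sigma):
    """0.5 * sum(1 + 2*log(sigma) - mu^2 - sigma^2), i.e. minus the KL divergence."""
    return 0.5 * sum(1 + 2 * ls - m * m - math.exp(2 * ls)
                     for m, ls in zip(mu, log_sigma))

def grad_neg_kld(mu, log_sigma):
    """Analytic gradients: d/dmu = -mu, d/dlog(sigma) = 1 - exp(2*log(sigma))."""
    return [-m for m in mu], [1 - math.exp(2 * ls) for ls in log_sigma]

def max_grad_error(mu, log_sigma, h=1e-6):
    """Largest gap between the analytic log(sigma) gradient and a central difference."""
    _, analytic = grad_neg_kld(mu, log_sigma)
    worst = 0.0
    for i in range(len(log_sigma)):
        up = list(log_sigma)
        up[i] += h
        down = list(log_sigma)
        down[i] -= h
        numeric = (neg_kld(mu, up) - neg_kld(mu, down)) / (2 * h)
        worst = max(worst, abs(numeric - analytic[i]))
    return worst

def ascent_step(mu, log_sigma, lr=0.1):
    """One gradient ascent step on log(sigma): this term is maximized, not minimized."""
    _, g = grad_neg_kld(mu, log_sigma)
    return [ls + lr * gi for ls, gi in zip(log_sigma, g)]

mu = [0.5, -1.0, 0.0]
log_sigma = [1.0, -0.5, 0.2]
print(max_grad_error(mu, log_sigma))  # should be tiny
print(neg_kld(mu, log_sigma), neg_kld(mu, ascent_step(mu, log_sigma)))
```

With a descent-style update (x = x - lr * grad), the same ascent is obtained by passing a negative learning rate, which is one way to read the sign discussion above.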

It's still unstable for me though, so I am on the hunt for another bug.

AjayTalati commented 9 years ago

Yep that looks right,

$$\frac{\partial\,\mathrm{KLD}}{\partial \log\sigma_i} = \frac{1}{2}\left[2 - 2\exp(2\log\sigma_i)\right] = 1 - \exp(2\log\sigma_i)$$

if you're using gradient ascent

AjayTalati commented 9 years ago

Joost, try initializing your parameters with small isotropic Gaussian noise, i.e. something like

    params, grad_params = va:getParameters()
    params:randn(params:size()):mul(0.01) -- initialize parameters with small Gaussians

I've found this to be helpful ;)

zencoding commented 9 years ago

Ajay, I already tried initializing with Gaussian noise; it did not help. Maybe there is some other bug; I am still hunting for it :)

y0ast commented 9 years ago

That's what the sigmaInit variable in LinearVA is for; however, that treats the symptoms, not the actual cause :).

I'll have another deeper look this weekend, thanks for spending some time to look at it zencoding and Ajay.

AjayTalati commented 9 years ago

Still not getting this to work, guys???

Any ideas where this is going wrong? It can only be in the reparametrization or the KLD calculation, right?

I'm guessing the best way to check is to run two reparametrization modules, or two KLD modules, side by side with the same inputs and compare them; that would give a decisive way of debugging!
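One way to carry out that side-by-side comparison, sketched in plain Python under the assumption that the old modules take log(sigma^2) as their second input and the new ones take log(sigma). Feeding the new modules log(sigma) = 0.5 * log(sigma^2) should reproduce the old forward output, and by the chain rule every gradient with respect to log(sigma) should be exactly twice the corresponding gradient with respect to log(sigma^2). A failing check would point at the module with the bug. All function names are hypothetical.

```python
import math

def close(a, b):
    return math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)

# Old modules: second input is log(sigma^2).
def old_forward(mu, log_var, eps):
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, log_var, eps)]

def old_reparam_grad(log_var, eps):
    # dz/dlog(sigma^2) = 0.5 * exp(0.5 * log(sigma^2)) * eps
    return [0.5 * math.exp(0.5 * lv) * e for lv, e in zip(log_var, eps)]

def old_kld_grad(log_var):
    # d/dlog(sigma^2) of 0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    return [0.5 * (1 - math.exp(lv)) for lv in log_var]

# New modules: second input is log(sigma).
def new_forward(mu, log_sigma, eps):
    return [m + math.exp(ls) * e for m, ls, e in zip(mu, log_sigma, eps)]

def new_reparam_grad(log_sigma, eps):
    return [math.exp(ls) * e for ls, e in zip(log_sigma, eps)]

def new_kld_grad(log_sigma):
    return [1 - math.exp(2 * ls) for ls in log_sigma]

def compare(mu, log_var, eps):
    """Feed both versions equivalent inputs and report which checks pass."""
    log_sigma = [0.5 * lv for lv in log_var]
    return {
        "forward": all(close(a, b) for a, b in
                       zip(old_forward(mu, log_var, eps), new_forward(mu, log_sigma, eps))),
        "reparam grad": all(close(2 * a, b) for a, b in
                            zip(old_reparam_grad(log_var, eps), new_reparam_grad(log_sigma, eps))),
        "kld grad": all(close(2 * a, b) for a, b in
                        zip(old_kld_grad(log_var), new_kld_grad(log_sigma))),
    }

print(compare(mu=[0.3, -1.0], log_var=[-1.0, 0.8], eps=[0.7, -1.3]))  # all three should be True
```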

zencoding commented 9 years ago

Tried a few changes to the parameters, still no luck. Interestingly, the old code (with log(sigma^2)) works fine, so I guess this has to do with some numerical instability with the changed precision. I am working out the math before getting back to the code; I will keep you guys updated.

AjayTalati commented 9 years ago

I think a good idea (if you want to extend this code to more difficult problems) is to break up Joost's nn.Sequential stack of modules into separate encoder, sampler and decoder gModules using nngraph. That's what I've done.

If you do this, you'll understand the maths as well, because you'll see where the gradients of the different loss functions backpropagate. I'd also break Joost's objective function up into forward and backward sections.

Doing that will help you if you want to code up an LSTM VAE. As the code stands, it's really specialized to this particular version of VAEs.
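A toy illustration of that split, assuming a one-dimensional latent, a linear encoder and decoder, and a squared-error reconstruction term (all hypothetical simplifications, not Joost's model and not nngraph itself). Each stage has its own forward and backward, so you can see that the reconstruction gradient reaches mu and log(sigma) through the sampler, while the KLD gradient is added directly at the encoder outputs. A finite-difference check confirms the hand-written backward passes.

```python
import math

# Toy scalar VAE: p holds hypothetical parameters a, b (encoder) and c (decoder).

def encoder_forward(p, x):
    return p["a"] * x, p["b"] * x                  # mu, log_sigma

def encoder_backward(x, g_mu, g_log_sigma):
    return {"a": g_mu * x, "b": g_log_sigma * x}

def sampler_forward(mu, log_sigma, eps):
    return mu + math.exp(log_sigma) * eps          # z

def sampler_backward(log_sigma, eps, g_z):
    return g_z, g_z * math.exp(log_sigma) * eps    # dL/dmu, dL/dlog_sigma

def decoder_forward(p, z):
    return p["c"] * z                              # x_hat

def decoder_backward(p, z, g_x_hat):
    return g_x_hat * z, g_x_hat * p["c"]           # dL/dc, dL/dz

def lower_bound(p, x, eps):
    mu, log_sigma = encoder_forward(p, x)
    x_hat = decoder_forward(p, sampler_forward(mu, log_sigma, eps))
    rec = -0.5 * (x - x_hat) ** 2                  # Gaussian log-likelihood up to a constant
    neg_kld = 0.5 * (1 + 2 * log_sigma - mu ** 2 - math.exp(2 * log_sigma))
    return rec + neg_kld

def gradients(p, x, eps):
    mu, log_sigma = encoder_forward(p, x)
    z = sampler_forward(mu, log_sigma, eps)
    x_hat = decoder_forward(p, z)
    # Reconstruction gradient: decoder -> sampler -> encoder outputs.
    g_c, g_z = decoder_backward(p, z, x - x_hat)
    g_mu_rec, g_ls_rec = sampler_backward(log_sigma, eps, g_z)
    # KLD gradient enters directly at the encoder outputs.
    g_mu_kld, g_ls_kld = -mu, 1 - math.exp(2 * log_sigma)
    grads = encoder_backward(x, g_mu_rec + g_mu_kld, g_ls_rec + g_ls_kld)
    grads["c"] = g_c
    return grads

def max_fd_error(p, x, eps, h=1e-6):
    analytic = gradients(p, x, eps)
    worst = 0.0
    for k in p:
        up, down = dict(p), dict(p)
        up[k] += h
        down[k] -= h
        numeric = (lower_bound(up, x, eps) - lower_bound(down, x, eps)) / (2 * h)
        worst = max(worst, abs(numeric - analytic[k]))
    return worst

params = {"a": 0.4, "b": -0.3, "c": 1.5}
print(max_fd_error(params, x=0.8, eps=0.25))       # should be tiny
```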

zencoding commented 9 years ago

Cool... I will try that. Do you have any examples to get started with, or should I browse online for nngraph examples?

y0ast commented 9 years ago

So I had some time this week and created an nngraph version: https://github.com/y0ast/VAE-Torch/tree/nngraph

It works (it seems to learn faster, with a stable KLD value), but parts of the code are a bit rough around the edges. I need to find out how to wrap the gModules more nicely. Maybe @AjayTalati has an idea?

AjayTalati commented 9 years ago

Looks good, Joost, glad it works! I'll have a think and get back to you later. Sorry, I'm very busy at the moment.