jbornschein / draw

Reimplementation of DRAW
MIT License

Exploding cell value outputs of the encoder? #3

Closed AjayTalati closed 9 years ago

AjayTalati commented 9 years ago

Hi, thanks for making this great implementation open source.

I'm working on a similar implementation in Torch (it's not working yet, either without or with filterbanks/attention), and I'd like to understand your code better, because I'm confused.

To be precise: is there an easy way, using Blocks, to track the norms of the encoder cell values, say ||c_enc_t||_2, at the end of the forward pass of each SGD step?

What I'm finding with my implementation is that ||c_enc_t||_2 gets really big after about 20 epochs once T gets larger than about 10 glimpses. Initially I was using T=64, but I reduced that to T=10 and it's still happening.

Just wondered if you saw this with your Blocks/Theano implementation? Thanks for your help.

Best, Aj

udibr commented 9 years ago

It sounds to me like a weight initialization problem. In any case, the gradient is clipped to 10 on line 158 of train-draw.py.
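
For reference, clipping in Blocks is done with a step rule; a minimal sketch of the pattern (cost and params stand in for whatever your model defines, and the choice of RMSProp here is illustrative, not necessarily what train-draw.py uses):

```python
from blocks.algorithms import (GradientDescent, CompositeRule,
                               StepClipping, RMSProp)

# Clip the norm of the update step to 10 before applying the actual
# learning rule -- the same idea as the clipping in train-draw.py.
algorithm = GradientDescent(
    cost=cost,          # the scalar training cost
    parameters=params,  # the model's parameter list
    step_rule=CompositeRule([StepClipping(10.),
                             RMSProp(learning_rate=1e-3)])
)
```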

jbornschein commented 9 years ago

It's easy to track individual variables in Blocks -- the only inconvenience here is that many of the interesting variables are "hidden" inside the DrawModel object, because I chose to encapsulate the whole logic in there.

What you have to do to monitor some specific quantity:
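
Roughly the following (untested sketch; the exact variable name depends on how the graph is built inside DrawModel, so 'c_enc' here is illustrative):

```python
from blocks.graph import ComputationGraph
from blocks.filter import VariableFilter
from blocks.extensions.monitoring import TrainingDataMonitoring

# Build the computation graph from the cost, then fish out the
# intermediate variable of interest by name.
cg = ComputationGraph([cost])
c_enc, = VariableFilter(name='c_enc')(cg.variables)

# Name the derived quantity so it shows up as a monitoring channel.
c_enc_norm = c_enc.norm(2)
c_enc_norm.name = 'c_enc_norm'

# Attach it to the main loop's extensions to log it during training.
monitor = TrainingDataMonitoring([c_enc_norm], prefix='train',
                                 after_epoch=True)
```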

AjayTalati commented 9 years ago

Dear Jorg,

thanks a lot for taking the time to give a helpful reply.

I'm still going through the Blocks code (I'm not familiar with it, so it's taking a while); as soon as I understand it, I'll do what you suggest.

I'm trying an experiment with my Torch implementation which seems to have fixed my exploding cell values problem. To be exact, I applied the batch normalization transform to the cell values and outputs of the encoder and decoder,

c_enc_t, h_enc_t, c_dec_t and h_dec_t

and also to x_hat_t, and to r_t and w_t (at the moment I'm testing my implementation without attention, so simply using eqns (17) and (18)). It seems to have worked: the losses so far in training are much more stable and reduce much faster :)
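
For concreteness, the transform I mean is ordinary batch normalization over the minibatch, applied separately at every time step (a numpy sketch with learned gain gamma and shift beta; not my actual Torch code):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize a (batch, units) activation over the batch dimension,
    then rescale with the learned gain (gamma) and shift (beta)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```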

Probably the most encouraging thing I'm seeing, though, is that while my reconstruction loss comes down dramatically, my latent loss (KLD) steadily increases -- which gives me some hope that I'm doing the right things and the code's correct :) I should be able to reproduce a plot similar to yours of the KL divergence over inference iterations and epochs later on.

I just wondered what sort of results I should expect with small datasets and without the attention mechanisms. In the screenshot below, the left plot is the final canvas C_T after a forward training pass through the system (eqns 3-8), and the right plot is the corresponding input image x presented to the read module. Did you try to produce these plots while training, and if so, did they get better as the epochs proceeded?

[screenshot: small_data_set_training]

I'm still waiting for the system to finish training; then I'll see what sort of quality the stochastically generated canvases have, as in section 2.3.

Have you got any advice on fast training configurations? Basically, what minimal configuration (e.g. train_set size, z_size, number of iterations, rnn_size, etc.) would be useful for code development in this project? I'm still not sure I haven't made a fundamental error in my implementation.

Maybe, if you have time, you might want to try applying batch normalization; the Torch code for it is here. I guess you could code it up fairly easily if there isn't already a Python/Theano implementation somewhere?

Best regards,

Aj

PS - my default train parameters for fast development/testing of my implementation without attention are:

- train_set = first 3000 digits of binarized MNIST 28x28, i.e. 5% of the full dataset
- minibatch_size = 100
- number_of_epochs = 100
- number of glimpses T = 16
- z_size = 100
- number of LSTM cells = 256
- initial parameter weights uniform in [-0.08, 0.08]
- gradients clipped to the range [-10, 10]
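
If it's useful, the same settings as a Python dict (the key names are mine, purely illustrative, not from any particular codebase):

```python
# Illustrative config mirroring the settings listed above.
config = {
    'train_set_size': 3000,   # first 3000 binarized MNIST digits (5%)
    'minibatch_size': 100,
    'n_epochs': 100,
    'n_glimpses': 16,         # T
    'z_size': 100,
    'rnn_size': 256,          # number of LSTM cells
    'init_range': 0.08,       # weights ~ Uniform[-0.08, 0.08]
    'grad_clip': 10.0,        # gradients clipped to [-10, 10]
}
```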

AjayTalati commented 9 years ago

Hi,

I think I understand most of your code, except your KL calculation on lines 80 to 85. I just can't derive it from eqn (11) of the paper after breaking it up into single time steps.

Any chance of an explanation?

In particular, I don't understand why you've used

self.prior_log_sigma = 0
self.prior_mean = 0

it seems that makes

tensor.exp(2 * self.prior_log_sigma) = 1 ??

and line 80 looks like it should be -log_sigma^2?

Sorry if the answer is obvious?

jbornschein commented 9 years ago

Yes, you are right: setting prior_log_sigma = 0 makes the denominator equal to 1. With prior_mean = 0, lines 80-85 therefore effectively read:

-log_sigma + 0.5*(exp(log_sigma)**2 + mean**2) - 0.5

which, when you consider log_sigma == 1/2 * log(exp(log_sigma)**2), makes the whole expression equivalent to their eqn (11).
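
For reference, that expression is just the closed-form KL between a diagonal Gaussian posterior and a unit-Gaussian prior, which is what eqn (11) reduces to per time step when the prior is fixed to N(0, 1):

```latex
\mathrm{KL}\bigl(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,1)\bigr)
  = \tfrac{1}{2}\bigl(\mu^{2} + \sigma^{2} - \log \sigma^{2} - 1\bigr)
  = -\log\sigma + \tfrac{1}{2}\bigl(\mu^{2} + \sigma^{2}\bigr) - \tfrac{1}{2}
```

with sigma = exp(log_sigma), this matches the code term by term.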

At least I can't spot any mistake when I look at it right now.

AjayTalati commented 9 years ago

Apologies Jorg, I didn’t see the equivalence,

log_sigma == 1/2 * log ( exp(log_sigma)**2 )

It looks right to me now as well. Thanks a lot for the kind explanation :+1: Your code's very elegant -- much better than mine!

jbornschein commented 9 years ago

Thanks! :)

I did not systematically perform experiments with the attention-less version on MNIST. The few experiments I ran did not converge to very good NLL bounds -- but much better than the one you described. IIRC my best experiments reached about 100 nats. But I did not spend much time optimizing hyperparameters.

Your reconstruction error looks about right for an untrained model. But with such a small training set I would expect you to overfit rapidly. 16 glimpses should also be enough for the attention-less model; I remember that I could train models with only very few (~4) iterations. Note that it should even work with only a single iteration -- the model then degrades to a normal VAE with 'funny' encoder and decoder networks.

I did not try batch normalization yet. @vdumoulin is currently preparing a Blocks implementation: https://github.com/bartvm/blocks/pull/513

AjayTalati commented 9 years ago

Thanks a lot for the great help/insight :+1:

> Note that it should even work with only a single iteration -- the model then degrades to a normal VAE with 'funny' encoder and decoder networks.

> The few experiments I ran did not converge to very good NLL bounds -- but much better than the one you described. IIRC my best experiments reached about 100 nats

Those reconstruction errors in the screenshots were calculated before I realized I had overlooked applying the sigmoid transform to the final canvas matrix C_T before calculating the BCE.
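
That is, the loss should be computed along these lines (a numpy sketch of the fix, not my actual Torch code):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def reconstruction_loss(canvas_T, x, eps=1e-7):
    """Binary cross-entropy between the input x and the final canvas
    C_T, squashed through a sigmoid first -- the step I had skipped."""
    p = np.clip(sigmoid(canvas_T), eps, 1.0 - eps)
    return -np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p))
```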

I've coded the standard MLP variational autoencoder and experimented with it on the (full) MNIST dataset. Using Adam, a small number of latent variables (z ~ 50) and ~200 hidden units, it comes down to around 103 nats after 33 epochs.

[screenshot: mlp_vae]

So I think I've made some fundamental mistake somewhere with my DRAW implementation, because my numbers are just too far off :( Guess I've got to go back and compare it with my standard VAE implementation -- that will be boring :(

I've coded up the attention modules, but it seems silly to try testing them until I've sorted out the basic system.

AjayTalati commented 9 years ago

Thanks for all your great help Jorg :+1:

I've got my code working a bit better now. The generated digits are still not very good, but I know what the problem is:

[screenshot: samples_generated_from_draw_without_attention_after_25_epochs]

It was just a bug/misunderstanding of the paper causing the cell value infinities -- I learnt a lot from your help.

Best regards,

Aj