mkusner / grammarVAE

Code for the "Grammar Variational Autoencoder" https://arxiv.org/abs/1703.01925

Replicating the results in pytorch #7

Closed: ZmeiGorynych closed this issue 6 years ago

ZmeiGorynych commented 6 years ago

Dear Matt,

after hearing the talk you gave in Cambridge on the Grammar VAE, I thought it would be fun to play with it in PyTorch, so I ported your code to PyTorch/Python 3, now at https://github.com/ZmeiGorynych/grammarVAE_pytorch

However, I have some questions about replicating your training: I use the Adam optimizer with lr = 5e-4, decreasing to 1e-4 on plateaus, and the loss function (https://github.com/ZmeiGorynych/grammarVAE_pytorch/blob/master/models/model_loss.py) is like this:

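```python
import torch

# reconstruction term, scaled by the sequence length
BCE = seq_len * self.bce_loss(model_out_x, target_x)
# element-wise KL(N(mu, exp(log_var)) || N(0, 1)) terms, averaged over batch and latent dims
KLD_element = (1 + log_var - mu*mu - log_var.exp())
KLD = -0.5 * torch.mean(KLD_element)
loss = BCE + KLD
```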

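and the encoder/decoder (https://github.com/ZmeiGorynych/grammarVAE_pytorch/blob/master/models/model_grammar_pytorch.py; settings here: https://github.com/ZmeiGorynych/grammarVAE_pytorch/blob/master/models/model_settings.py), which as far as I can tell is an exact replica of your functions in model_zinc.py. I'm using batch size 200, as that is the most that will fit on a p2.xlarge with my implementation of the network.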
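For reference, the KL term in the loss snippet is the closed-form divergence between the encoder's diagonal Gaussian N(mu, exp(log_var)) and the standard-normal prior, and `torch.mean` averages it over both the batch and the latent dimensions. A framework-free sketch of that same reduction in plain Python, offered only as an illustration (the function name `gaussian_kl_mean` and the list-of-rows inputs are assumptions, not code from either repository):

```python
import math


def gaussian_kl_mean(mu, log_var):
    """Mean of the element-wise KL(N(mu, exp(log_var)) || N(0, 1)) terms.

    mu and log_var are lists of rows (one row per example, one column per
    latent dimension), mirroring -0.5 * torch.mean(1 + log_var - mu*mu - log_var.exp()).
    """
    terms = [
        1.0 + lv - m * m - math.exp(lv)  # one latent dimension of one example
        for mu_row, lv_row in zip(mu, log_var)
        for m, lv in zip(mu_row, lv_row)
    ]
    return -0.5 * sum(terms) / len(terms)


print(gaussian_kl_mean([[1.0, 0.0]], [[0.0, 0.0]]))  # 0.25
```

Because the mean runs over latent dimensions as well as examples, this value is smaller, by a factor of the latent size, than the per-example KL summed over latent dimensions, which is the form usually written in the VAE objective.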

Now, you seem to train for 100 epochs, which would be 125000 batches for me. However, when I train with the above parameters, running one validation batch after every 10 training batches, I get the following loss values (x-axis: batches): [image: https://user-images.githubusercontent.com/22304254/37703901-2191f8da-2cef-11e8-83f1-eb286b8f81a6.png] In other words, the loss saturates at about 1.8 after a couple of epochs and stays there.
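The batch count is batches per epoch times epochs. Working backwards from the numbers in this comment, 125000 batches over 100 epochs is 1250 batches per epoch, which at batch size 200 implies about 250,000 training examples (an inferred figure, not one stated in the thread). A small helper to redo that arithmetic; the function name is hypothetical:

```python
import math


def total_batches(n_examples, batch_size, epochs):
    """Number of training batches for `epochs` passes, counting a final partial batch."""
    batches_per_epoch = math.ceil(n_examples / batch_size)
    return batches_per_epoch * epochs


print(total_batches(250_000, 200, 100))  # 125000
```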

When I add a lot of dropout, turn off the sampling of z (just taking the mean instead), and replace the KL term with a simple penalty on the deviation of the batch mean and covariance matrix of z from those of N(0, 1), the model trains much better, reaching a loss of about 0.5 over the same period as in the figure above: [image: https://user-images.githubusercontent.com/22304254/37704135-f144c0e4-2cef-11e8-9978-8fd1362d56ab.png]
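The exact form of this replacement regularizer is not spelled out in the thread. One plausible reading, sketched below in plain Python as an assumption, is a squared-error penalty pulling the batch mean of z towards 0 and its batch covariance towards the identity. The name `moment_penalty` and the choice of the biased (divide-by-n) covariance are illustrative:

```python
def moment_penalty(z):
    """Squared deviation of the batch mean/covariance of z from those of N(0, I).

    z is a list of rows (batch_size x latent_dim). Uses the biased
    (divide-by-n) covariance estimate; this exact form is an assumption.
    """
    n, d = len(z), len(z[0])
    mean = [sum(row[j] for row in z) / n for j in range(d)]
    centered = [[row[j] - mean[j] for j in range(d)] for row in z]
    penalty = sum(m * m for m in mean)  # squared distance of the mean from 0
    for i in range(d):
        for j in range(d):
            cov_ij = sum(row[i] * row[j] for row in centered) / n
            target = 1.0 if i == j else 0.0
            penalty += (cov_ij - target) ** 2  # squared Frobenius distance to I
    return penalty
```

Unlike the per-example KL term, a penalty of this kind only constrains the aggregate distribution of z over the batch, so individual codes are free to be tightly concentrated.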

Any idea what I could be doing wrong? Should I just de-weight the KL term further until it works?

Thanks a lot for any suggestions, E.

mkusner commented 6 years ago

Oh awesome! Thanks for doing this! Glad to hear you're interested in the project!

One immediate difference, unless I missed something, is that we scale down the encoder's sampling noise by a factor of 0.01, as shown here: https://github.com/mkusner/grammarVAE/blob/master/models/model_zinc.py#L98 (epsilon_std, defined a few lines before as 0.01, is the standard deviation of the noise). This follows the CVAE code by Max Hodak: https://github.com/maxhodak/keras-molecules/blob/master/molecules/model.py#L68

We believe this has the effect of further down-weighting the KL term, which is useful for preventing the encoder from immediately matching the prior. I think your addition of dropout likely had a similar effect, by keeping the encoder from learning as quickly. There are definitely much more elegant ways to do what we did, such as the nice recent work by @nowozinmsr: http://www.nowozin.net/sebastian/papers/nowozin2018jvi.pdf

If you want a super quick fix, try just down-weighting the variance; for a more principled fix, I would go with the above work!
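As far as we can tell from the linked Keras code, the sampling step draws the noise with standard deviation epsilon_std instead of 1 before computing mu + exp(log_var / 2) * epsilon, so the spread of the sampled z shrinks by that factor while the KL term is still computed from the unscaled log_var. A plain-Python sketch of that reparameterization (the function name `sample_z` is hypothetical; a PyTorch port would operate on tensors instead of lists):

```python
import math
import random


def sample_z(mu, log_var, epsilon_std=0.01, rng=random):
    """Reparameterized sample z = mu + exp(log_var / 2) * eps, eps ~ N(0, epsilon_std^2).

    With epsilon_std=1.0 this is the standard VAE sampling step; smaller values
    shrink the injected noise without changing mu or log_var.
    """
    return [m + math.exp(lv / 2.0) * rng.gauss(0.0, epsilon_std)
            for m, lv in zip(mu, log_var)]
```

Setting `epsilon_std` to 0 recovers the deterministic "just take the mean" variant described in the first comment.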

(Email reply to Egor Kraev's comment of Wed, Mar 21, 2018 at 10:12 AM.)

ZmeiGorynych commented 6 years ago

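Ah, in the meantime I just multiplied the KL term by 0.01, and that seems to work like a charm :) [image: https://user-images.githubusercontent.com/22304254/37708679-832e746e-2cff-11e8-91eb-af1f8bc8690c.png]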
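Multiplying the KL term by a constant amounts to training on the weighted objective loss = BCE + kl_weight * KLD. A minimal plain-Python sketch with the 0.01 weight mentioned here, using the same mean-reduced KL as the loss snippet in the first comment; `vae_loss` and `kl_weight` are illustrative names, and the reconstruction loss is passed in already computed:

```python
import math


def vae_loss(bce, mu, log_var, kl_weight=0.01):
    """Reconstruction loss plus a down-weighted KL term.

    bce: already-scaled reconstruction loss (a float).
    mu, log_var: flat lists covering every batch element and latent dimension.
    """
    kl_terms = [1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)]
    kld = -0.5 * sum(kl_terms) / len(kl_terms)  # same mean reduction as torch.mean
    return bce + kl_weight * kld
```

With kl_weight=1.0 this reduces to the original `loss = BCE + KLD`.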

I'll take a look at the reference, thanks!

The other thing I'm looking at right now is simultaneously training (a simple dense layer on top of) the encoder to determine whether a given SMILES string is valid. I'm curious to see what effect that will have on training speed and on the fraction of valid molecules. Will let you know once I get results.
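A minimal sketch of such an auxiliary head, under stated assumptions: a single dense (logistic) unit on the encoder's output (taken here to be the latent vector z), trained with binary cross-entropy on valid/invalid labels and added to the VAE loss with a weight. The names `validity_logit`, `validity_loss`, `total_loss` and `aux_weight` are hypothetical, and how the labels are obtained is left open:

```python
import math


def validity_logit(z, weights, bias):
    """Single dense unit: logit that the encoded SMILES string is valid."""
    return sum(w * x for w, x in zip(weights, z)) + bias


def validity_loss(logit, is_valid):
    """Numerically stable binary cross-entropy with logits (label 1 = valid)."""
    y = 1.0 if is_valid else 0.0
    # max(l, 0) - l*y + log(1 + exp(-|l|)) equals -[y*log(s) + (1-y)*log(1-s)], s = sigmoid(l)
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))


def total_loss(vae_loss_value, z, weights, bias, is_valid, aux_weight=1.0):
    """VAE objective plus a weighted validity-classification term."""
    return vae_loss_value + aux_weight * validity_loss(
        validity_logit(z, weights, bias), is_valid)
```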

mkusner commented 6 years ago

Great!

That sounds cool!! I'm interested to hear the results!

(Email reply to Egor Kraev's comment of Wed, Mar 21, 2018 at 12:04 PM.)

mkusner commented 6 years ago

I'll assume this is closed. If not I'm happy to open again!