GT4SD / gt4sd-core

GT4SD, an open-source library to accelerate hypothesis generation in the scientific discovery process.
https://gt4sd.github.io/gt4sd-core/
MIT License

MOSES VAE from Guacamol training reconstruction is "incorrect" #176

Closed · davidegraff closed this issue 1 year ago

davidegraff commented 1 year ago

Describe the bug
The VAE in GT4SD uses the Guacamol wrapper of the Moses VAE. Unfortunately, the decoder training step of the Moses VAE is bugged.

More detail
The problem arises from the definition of the forward_decoder method:

def forward_decoder(self, x, z):
    # x: list of token-id tensors (one per sequence); z: latent codes of shape (batch, d_z)
    lengths = [len(i_x) for i_x in x]

    x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.pad)
    x_emb = self.x_emb(x)

    # z is tiled across the sequence and concatenated to every token embedding
    z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
    x_input = torch.cat([x_emb, z_0], dim=-1)  # <--- PROBLEM 1
    x_input = nn.utils.rnn.pack_padded_sequence(x_input, lengths, batch_first=True)

    # z also initializes the hidden state of every decoder layer
    h_0 = self.decoder_lat(z)
    h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)

    output, _ = self.decoder_rnn(x_input, h_0)

    output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
    y = self.decoder_fc(output)

    # teacher-forced next-token loss; the default reduction is "mean" over tokens
    recon_loss = F.cross_entropy(  # <--- PROBLEM 2
        y[:, :-1].contiguous().view(-1, y.size(-1)),
        x[:, 1:].contiguous().view(-1),
        ignore_index=self.pad
    )

    return recon_loss

Namely, the reconstruction step is wrong in two spots:

  1. Construction of the decoder input: x_input = torch.cat([x_emb, z_0], dim=-1). In the usual visual representation of an RNN, the true token feeds in from the "bottom" of the cell and the previous hidden state from the "left". In this implementation, the reparameterized latent vector z is fed in both from the "left" (normal, via h_0) and from the "bottom" (atypical, concatenated to every token embedding). Fix: this line should be removed (which also means the decoder RNN must be built with input size d_emb rather than d_emb + d_z).
  2. Calculation of the reconstruction loss: recon_loss = F.cross_entropy(...). The reconstruction loss is computed as the per-token loss of the input batch (i.e., the mean over every token in the batch) because the default reduction in F.cross_entropy is "mean". In turn, this results in reconstruction losses that are very low for the VAE, causing the optimizer to ignore the decoder and focus on the encoder. When a VAE focuses too hard on the encoder, you get mode collapse, and that is what happens with the Moses VAE. Fix: this line should be F.cross_entropy(..., reduction="sum") / len(x), i.e., sum over tokens and average per sequence (see the sketch after this list).
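
For concreteness, here is a minimal, self-contained sketch of the loss fix, using dummy tensors in place of the real decoder output (pad, batch, seq_len, and vocab are placeholder values, not names from the Moses code):

import torch
import torch.nn.functional as F

pad = 0                                        # assumed padding token id
batch, seq_len, vocab = 4, 10, 30
y = torch.randn(batch, seq_len, vocab)         # stand-in for the decoder logits
x = torch.randint(1, vocab, (batch, seq_len))  # stand-in for the padded target token ids

# Buggy version: the default reduction="mean" averages over every token in the batch.
loss_per_token = F.cross_entropy(
    y[:, :-1].contiguous().view(-1, vocab),
    x[:, 1:].contiguous().view(-1),
    ignore_index=pad,
)

# Proposed fix: sum over tokens, then average per sequence.
loss_per_sequence = F.cross_entropy(
    y[:, :-1].contiguous().view(-1, vocab),
    x[:, 1:].contiguous().view(-1),
    ignore_index=pad,
    reduction="sum",
) / x.size(0)

print(loss_per_token.item(), loss_per_sequence.item())  # the fixed loss is ~(seq_len - 1)x larger

With the fix, the magnitude of the reconstruction term scales with sequence length, so it is no longer dwarfed by the KL term during training.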

To reproduce

  1. Problem 1 is not a "problem" so much as a highly atypical way to structure a VAE. I can't say whether it causes any actual harm, but it simply shouldn't be there.
  2. Problem 2 can be observed with two experiments:
    1. Using PCA with two components, plot the embeddings of a random batch, z ~ q(z|x), alongside samples from the standard normal distribution, z ~ N(0, I). The embeddings from the encoder will look like a single point at (0, 0) next to the standard-normal samples.
    2. Measure the reconstruction accuracy x_r ~ p(x | z), z ~ q(z | x_0). In a well-trained VAE, sum(x_r == x_0 for x_r, x_0 in zip(x_rs, xs)) / len(xs) should be above 50%; in my experience, this VAE scores fairly low. Both experiments are sketched below.
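
A sketch of both experiments, assuming hypothetical encode and reconstruct helpers that wrap the model's encoder and decoder (these names are illustrative, not part of the Moses API):

import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def diagnose(encode, reconstruct, smiles, d_z):
    # encode(smiles) -> z ~ q(z|x) with shape (n, d_z); reconstruct(z) -> decoded strings
    with torch.no_grad():
        z_q = encode(smiles)                # posterior samples
    z_p = torch.randn(len(smiles), d_z)     # prior samples, z ~ N(0, I)

    # Experiment 1: project both latent sets onto two PCA components.
    pca = PCA(n_components=2).fit(torch.cat([z_q, z_p]).numpy())
    q2, p2 = pca.transform(z_q.numpy()), pca.transform(z_p.numpy())
    plt.scatter(p2[:, 0], p2[:, 1], alpha=0.3, label="z ~ N(0, I)")
    plt.scatter(q2[:, 0], q2[:, 1], alpha=0.3, label="z ~ q(z|x)")
    plt.legend()
    plt.show()  # with the bug, the q(z|x) cloud collapses to a point near (0, 0)

    # Experiment 2: reconstruction accuracy through the posterior.
    with torch.no_grad():
        x_r = reconstruct(z_q)
    acc = sum(a == b for a, b in zip(x_r, smiles)) / len(smiles)
    print(f"reconstruction accuracy: {acc:.1%}")  # expect > 50% in a well-trained VAE
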
drugilsberg commented 1 year ago

Hi @davidegraff, thanks for reporting an issue. Could you please share some details on the bug you found?

For example, how to reproduce it or a permalink to the part of the code that you find problematic? Thanks in advance.

davidegraff commented 1 year ago

Yeah, I accidentally submitted too early. I'm editing right now.

drugilsberg commented 1 year ago

Cool, thanks a lot for reporting this! Did you open an issue in the guacamol_baselines repository? If they are not interested in solving this, we could consider implementing a fix in the fork we adapted to install the wrapper here: https://github.com/GT4SD/guacamol_baselines

davidegraff commented 1 year ago

I haven't, but the last commits to either the Moses or Guacamol core code were 2+ years ago. I don't really think they have much interest in fixing this problem, but I could be wrong. Just figured I'd bring it to your attention.

drugilsberg commented 1 year ago

Thanks a lot, much appreciated, we will surely look into this.

jannisborn commented 1 year ago

Hi @davidegraff, I agree with your assessment and opened a PR here: https://github.com/GT4SD/moses/pull/3

Can you have a look and confirm? I see the tests passing locally, so it seems fine. Once we merge that PR, we can bump the dependency here, and that should resolve your issue.

davidegraff commented 1 year ago

Given the VAE change in GT4SD/moses#3, will the available pretrained VAE be updated?

jannisborn commented 1 year ago

If you have the time, this could easily be done on your side with two commands:

  1. gt4sd-saving ... - trains the model and saves it in such a way that you can use it with gt4sd-inference (basically it adds the model to the ~/.gt4sd folder).
  2. gt4sd-upload ... - adds the model to the model hub so it becomes available to all users (it functions similarly to the huggingface model hub). You would have to assign an algorithm name, e.g., v1.

For examples, please see the README. We're happy to assist you through the process if needed :)

jannisborn commented 1 year ago

@davidegraff Just prepared a PR with the fix: https://github.com/GT4SD/gt4sd-core/pull/178

BTW: I also updated the discovery-demo and added the PaccMann VAE. It's the model from this paper: https://iopscience.iop.org/article/10.1088/2632-2153/abe808/meta

But instead of conditioning it on a protein sequence, the demo uses it unconditionally. This way it behaves quite similarly to the Moses VAE; both are autoregressive SMILES VAEs.

jannisborn commented 1 year ago

@davidegraff The fix is available on main now and will be included in the next release ASAP. I retrained the Moses VAE model after the fix and removed the old model from the model hub. In case you still have the Moses VAE cached locally, please remove it from the cache (rm -rf ~/.gt4sd/conditional_generation/MosesGenerator/VaeGenerator/v0). Running the model afterwards will trigger the download of the updated model.

Thx again for pointing this out :) Hope it helps you with your future work

jannisborn commented 1 year ago

The new release is out (1.0.4)!

davidegraff commented 1 year ago

thanks for fixing this!