oriondollar / TransVAE

A Transformer Based VAE Architecture for De Novo Molecular Design

Query on latent representation and length prediction #2

Closed PyeongKim closed 3 years ago

PyeongKim commented 3 years ago

First of all, it was a great pleasure to read such a well-organized and insightful paper. While looking through your code, I was a bit curious about how you predict the sequence length. The docstring of the predict_mask_length method in the VAEEncoder class reads, "Predicts mask length from latent memory". However, in the forward method, the sequence length is inferred from mu instead of mem. Did I miss something, or is this a small bug? Thank you. Sincerely, Kim

```python
class VAEEncoder(nn.Module):

    def __init__(self, layer, N, d_latent, bypass_bottleneck, eps_scale):
        ...
        self.predict_len1 = nn.Linear(d_latent, d_latent*2)
        self.predict_len2 = nn.Linear(d_latent*2, d_latent)
        ...

    def predict_mask_length(self, mem):
        ...
        pred_len = self.predict_len1(mem)
        ...

    def forward(self, x, mask):
        ...
        pred_len = self.predict_len1(mu)
        ...

    def forward_w_attn(self, x, mask):
        ...
        pred_len = self.predict_len1(mu)
        ...
```

oriondollar commented 3 years ago

Hi, I'm glad you found it insightful!

That's not a bug; the variables are named that way to distinguish between the inputs used during inference and during training. The predict_mask_length function is only called during sampling, so its input is a randomly generated vector the size of the latent memory, which is what I refer to as mem. During training, the model learns to predict the length of the SMILES string from the latent mean vector, mu, prior to reparameterization. So mem is essentially a randomly sampled mu.
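For anyone else reading this, here is a minimal sketch of that distinction. It is not the repo's actual code: the class name `LengthPredictor`, the `max_len` output size, and the ReLU between the two linear layers are assumptions made for illustration; only the general pattern (same length head fed `mu` during training and a randomly drawn latent vector during sampling) mirrors the snippet quoted above.

```python
import torch
import torch.nn as nn

class LengthPredictor(nn.Module):
    """Toy two-layer head that predicts sequence-length logits from a
    d_latent-sized vector (hypothetical stand-in for the real encoder head)."""
    def __init__(self, d_latent, max_len):
        super().__init__()
        self.predict_len1 = nn.Linear(d_latent, d_latent * 2)
        self.predict_len2 = nn.Linear(d_latent * 2, max_len)  # output size is an assumption

    def forward(self, z):
        return self.predict_len2(torch.relu(self.predict_len1(z)))

d_latent, max_len, batch = 128, 126, 4
head = LengthPredictor(d_latent, max_len)

# Training: length is predicted from the latent mean mu (pre-reparameterization).
mu = torch.randn(batch, d_latent)            # stands in for the encoder's mu
train_len_logits = head(mu)

# Sampling: "mem" is just a randomly drawn vector of the same size as mu,
# so the same head predicts a length for the generated SMILES.
mem = torch.randn(batch, d_latent)
sample_len_logits = head(mem)
```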