ayushkarnawat / profit

Exploring evolutionary protein fitness landscapes
MIT License

[WIP] CbAS #104

Closed: ayushkarnawat closed this 4 years ago

ayushkarnawat commented 4 years ago

What does this PR do?

Implements the Conditioning by Adaptive Sampling (CbAS) procedure on the GB1 protein fitness dataset.

ayushkarnawat commented 4 years ago

There are a lot of issues with the current implementation. In particular, the computed weights for each new batch of sampled data keep decreasing (quite substantially) after every iteration. As a result, the VAE does not learn the right sequences to sample for the next iteration cycle and gets stuck in a local optimum of sequences. Even with weighted random sampling of the softmax predictions, completely random sequences are generated, so the model does not really learn much in terms of the latent representation.

Might want to visualize the latent space as the model (with updated samples to choose from) is training, to understand how the encoding z changes over time. In fact, the primary question becomes: is it even learning to sample new sequences that are still close to the original, but not yet seen/computed?
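One minimal way to check this, as a sketch (assuming the VAE exposes an encode method returning the mean and log-variance of q(z|x); the method name is hypothetical), is to project the encoded means to 2D with PCA after each iteration:

import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_latent(vae, x, it):
    # Hypothetical API: vae.encode(x) -> (mu, logvar) of q(z|x).
    mu, _ = vae.encode(x)
    # Project the latent means onto their top-2 principal components.
    _, _, V = torch.pca_lowrank(mu, q=2)
    z2d = (mu - mu.mean(dim=0)) @ V[:, :2]
    plt.scatter(z2d[:, 0], z2d[:, 1], s=5)
    plt.title(f"Latent means, iteration {it}")
    plt.savefig(f"latent_{it}.png")
    plt.clf()

If the cloud of points collapses or drifts far from the region occupied by the original training sequences, the search model is no longer proposing nearby variants.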

Perhaps the best way is to reproduce the results using the original data. That is to say, we should (a rough sketch of the step 4 weight computation follows this list):

  1. Train a generative VAE model on the original data
  2. Train the oracle
  3. Train the ground truth GPR
  4. Optimize using CbAS
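
For reference, here is a minimal sketch of the weight computation in step 4, under my reading of the CbAS paper (Brookes et al., 2019). The oracle (returning a predictive mean and variance), vae_0 (the prior from step 1), vae_t (the current search model), and their log_prob methods are placeholders; a VAE's marginal log-likelihood is intractable and would in practice be approximated (e.g., via the ELBO):

import torch

def cbas_weights(x, oracle, vae_0, vae_t, gamma):
    # Survival probability P(y >= gamma | x) under the oracle's
    # Gaussian predictive distribution.
    mu, var = oracle(x)
    surv = 1.0 - torch.distributions.Normal(mu, var.sqrt()).cdf(gamma)
    # Importance ratio p_0(x) / p_t(x) between the prior generative
    # model and the current search model (hypothetical log_prob API).
    log_ratio = vae_0.log_prob(x) - vae_t.log_prob(x)
    return surv * log_ratio.exp()

Each iteration samples a batch from vae_t, sets gamma to the current quantile of the oracle predictions, computes these weights, and refits the VAE by weighted maximum likelihood. If the weights shrink toward zero every iteration (as observed above), either the search model has drifted too far from the prior or gamma is being raised too aggressively.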
ayushkarnawat commented 4 years ago

After further testing, it seems that when training the oracle we can sometimes get a negative computed loss, which means that the gradients will not get computed properly. As such, the oracle will not learn to predict the right y values. This could result in the right set of sequences not getting picked for sampling in the next iteration, which could affect the results.

ayushkarnawat commented 4 years ago

Testing the Gaussian NLL loss, it seems that for the following predicted \mu and \sigma^2 the computed loss is correct:

import math

import torch
from torch.nn import functional as F

pred = torch.Tensor([[0.80, 0.20],
                     [1.30, 0.10],
                     [1.50, 0.15],
                     [0.20, 0.20],
                     [0.50, 0.10]])
target = torch.Tensor([0.7, 1.25, 1.4, 0.2, 0.4])
N = target.size(0)

mu = pred[:, 0]
var = F.softplus(pred[:, 1]) + 1e-6
logvar = torch.log(var)

# Mean loss
0.5 * torch.log(torch.Tensor([math.tau])) + 0.5 * torch.mean(logvar) \
            + torch.mean(torch.square(target - mu) / (2 * var))
# tensor([0.7930])

# Summed loss (exactly N times the mean)
0.5 * N * torch.log(torch.Tensor([math.tau])) + 0.5 * torch.sum(logvar) \
        + torch.sum(torch.square(target - mu) / (2 * var))
# tensor([3.9651])

Therefore, the problem potentially lies somewhere else. From various online sources, it seems that the Gaussian negative log-likelihood is allowed to be negative (i.e., there is no lower bound at 0).
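
As a sanity check, the same numbers can be reproduced with PyTorch's built-in Gaussian NLL (this assumes a PyTorch version that ships F.gaussian_nll_loss; full=True adds the constant 0.5 * log(2*pi) term used above):

# Continuing from the snippet above.
F.gaussian_nll_loss(mu, target, var, full=True, reduction="mean")
# tensor(0.7930)
F.gaussian_nll_loss(mu, target, var, full=True, reduction="sum")
# tensor(3.9651)

Since the built-in loss agrees, negative values seen during training are a property of the Gaussian NLL itself, not a bug.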

ayushkarnawat commented 4 years ago

After testing the algorithm on the original dataset provided in the original paper, it seems that the algorithm performs quite well: it is able to find "interesting" points to sample, such that it finds a better variant than the ones provided. This is done quite efficiently as well, without running the algorithm on the whole dataset. Below, we summarize why the algorithm works well on the original dataset but quite poorly on our (much smaller) dataset.

  1. There are a lot of points to train on in the GFP data.
    • This affects how the VAE, and subsequently the other models, are trained, since more data usually allows the model to learn a better representation of the latent space z. This lets us sample points more effectively.
    • It also affects both the oracle and GPR predictions, since each is able to learn its parameter weights more accurately.
  2. The GFP data also only keeps values greater than the mean threshold. This might help the oracle and GP learn decent variants, which will eventually help in predicting yt and yt_gt more accurately (see the filtering sketch after this list).
  3. The oracle is also a basic dense (fully connected) model rather than a more powerful one based on, e.g., an LSTM.
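
For point 2, a minimal sketch of applying the same filtering to our data (the file path and column name are hypothetical):

import pandas as pd

df = pd.read_csv("gb1_fitness.csv")  # hypothetical path and schema
# Keep only variants whose fitness is above the dataset mean,
# mirroring the thresholding applied to the GFP data.
df_high = df[df["fitness"] > df["fitness"].mean()]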
ayushkarnawat commented 4 years ago

Trying the dense oracle on the relatively small dataset, the results are still sub-optimal. This is because the oracle predicts very small values for the sequences it is given (e.g., the variant AWGV gets a score that is lower than the true GT by quite a large margin). This is likely because, in the original dataset, many sequences are scored with a fitness of 0 (or very near 0). As a result, the learned weights of the oracle model are quite poor. The figure below shows a paired plot between the oracle predictions and the true GT fitness scores.

[Figure true_vs_pred: oracle mean prediction vs. ground-truth fitness for each training variant]

Each point corresponds to a variant of the 3GB1 sequence from the training dataset. The horizontal axis reports the ground-truth fitness value of the sequence, and the vertical axis reports the oracle's mean prediction for that sequence. Note that even within the training distribution, the oracle is unable to differentiate what makes a good sequence (one with a high fitness score) from a bad one: two sequences with very different GT fitness values are scored nearly the same by the oracle. As stated above, this might be attributed to the fact that many sequences in the training dataset have a fitness score of 0.

One way to remedy this might be to give higher sample weights to the underrepresented (high-fitness) sequences when training the oracle itself. Alternatively, it might be worth looking into getting rid of all samples whose fitness is below a certain threshold (e.g., the mean).
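
A minimal sketch of the reweighting idea, assuming the fitness scores sit in a 1-D tensor y (the binning scheme is just one possible choice): weight each sample inversely to how populated its fitness bin is.

import torch

def inverse_bin_weights(y, n_bins=20):
    # Bin the fitness scores and count how many samples land in each bin.
    edges = torch.linspace(float(y.min()), float(y.max()), n_bins + 1)
    idx = torch.bucketize(y, edges[1:-1])
    counts = torch.bincount(idx, minlength=n_bins).float()
    # Sparse (e.g. high-fitness) bins get proportionally larger weights.
    w = 1.0 / counts[idx]
    return w * len(y) / w.sum()  # normalize to a mean weight of 1

These weights could then be multiplied into the per-sample NLL, or passed to torch.utils.data.WeightedRandomSampler.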

ayushkarnawat commented 4 years ago

Addressing point 2 above: it seems that, due to the nature of the small initial dataset, subsetting the data by training only on the high-fitness values does not actually improve the oracle predictions when compared against the true values. In fact, the predictions become worse.

The problem is that if the oracle is always going to predict the same value (with very little variance), then different variants/sequences will all receive the same fitness score. Therefore, even if we sample new variants (from the generative model), the oracle will always predict the same score, and it never really learns where to sample next.
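
A quick diagnostic for this collapse, as a sketch (assuming oracle(x) returns a (mu, var) pair, as in the weight computation above):

import torch

@torch.no_grad()
def oracle_collapsed(oracle, x, tol=1e-3):
    # If the spread of predicted means across a diverse batch is tiny,
    # the oracle is effectively constant and gives no search signal.
    mu, _ = oracle(x)
    return mu.std().item() < tol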

ayushkarnawat commented 4 years ago

If we run the CbAS algorithm on the dataset described by the 2D matrix (oracle vs. GT, as plotted above) and observe that similar sequences land within the top k on each of the different runs, we can (likely) conclude that the algorithm is able to find the right location/region of sequence space even though it does not know the "true" value of each sequence. This would still hold true even if the fitness scores of the sequences it returns are not better/higher than the ones already provided.
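
A minimal sketch of that consistency check (the run results and their format are hypothetical: one ranked list of sequences per random seed):

from itertools import combinations

def topk_overlap(runs, k=10):
    # runs: list of rankings, each a list of sequences from one seed.
    tops = [set(r[:k]) for r in runs]
    pairs = list(combinations(tops, 2))
    # Mean Jaccard similarity between the top-k sets of any two runs.
    return sum(len(a & b) / len(a | b) for a, b in pairs) / len(pairs)

An overlap close to 1 across seeds would support the claim that the search consistently converges to the same region.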

This might be the only way to show that the algorithm works on such a small dataset. To be honest, every other hope is lost; I've tried literally all of the above.

ayushkarnawat commented 4 years ago

The sequences that are predicted to be highest are already contained within the original dataset. Rather, we should remove the top-k sequences from the original dataset before training.
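
A minimal sketch of that hold-out step, assuming a DataFrame with a "fitness" column (the column name is hypothetical):

import pandas as pd

def holdout_topk(df, k=10):
    # Drop the k best variants from training; if CbAS later proposes
    # them on its own, that is evidence the search itself works.
    top = df.nlargest(k, "fitness")
    return df.drop(top.index), top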

ayushkarnawat commented 4 years ago

Ok, it seems like all the problems with CbAS are fixed. Further cleanup will be handled later.