Closed ayushkarnawat closed 4 years ago
There are a lot of issues with the current implementation. In particular, the newly computed weights for each new batch of sampled data are always decreasing (quite substantially) after every iteration. As a result, the VAE does not learn the right sequences to sample for the next iteration cycle, and it gets stuck in a local optimum of sequences. Even with weighted random sampling of the softmax predictions, completely random sequences are generated, so the model does not really learn much in terms of the latent representation.
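For reference, this is roughly what weighted random sampling over the decoder's softmax outputs looks like (a minimal sketch; the logits, alphabet size, and sequence length here are made up for illustration, not taken from the repo):

```python
import torch

# Hypothetical decoder output: logits over a 4-letter alphabet for a
# length-3 sequence (illustrative values only).
logits = torch.tensor([[2.0, 0.1, 0.1, 0.1],
                       [0.1, 2.0, 0.1, 0.1],
                       [0.1, 0.1, 2.0, 0.1]])
probs = torch.softmax(logits, dim=-1)

# Weighted random sampling: draw one residue index per position in
# proportion to its softmax probability (instead of a greedy argmax).
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(sampled.shape)  # torch.Size([3])
```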
Might want to visualize the latent space as the model (with updated samples to choose from) is training, to understand how the encoding z changes over time. In fact, the primary question becomes: is it even learning to sample new sequences that are still close to the original ones, but not yet seen/computed?
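One low-effort way to do this visualization (a sketch only; the latent codes and fitness values below are random stand-ins for the VAE encoder's outputs) is to project the latent codes to 2D at each iteration and scatter-plot them colored by fitness, so successive iterations can be compared on the same axes:

```python
import torch

# Stand-ins for the encoder's mean latent vectors z and the fitness scores
# of the corresponding sequences (in practice these come from the model).
z = torch.randn(100, 16)
fitness = torch.rand(100)

# Project the latent codes to 2D with PCA.
z_centered = z - z.mean(dim=0)
U, S, V = torch.pca_lowrank(z_centered, q=2, center=False)
z2d = z_centered @ V  # (100, 2) coordinates, e.g. for a scatter colored by fitness
print(z2d.shape)  # torch.Size([100, 2])
```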
Perhaps the best way is to reproduce the results using the original data.
After further testing, it seems that when training the oracle we can sometimes get a negative loss, which means that the gradients will not be computed properly. As such, the oracle will not learn to predict the right y values. This could result in the wrong set of sequences getting picked for sampling in the next iteration, which could affect the results.
Testing the Gaussian NLL loss, it seems that for the following predicted $\mu$ and $\sigma^2$, the computed loss is correct:
```python
import math

import torch
from torch.nn import functional as F

# Each row: (raw mean, raw variance) predicted by the oracle head.
pred = torch.Tensor([[0.80, 0.20],
                     [1.30, 0.10],
                     [1.50, 0.15],
                     [0.20, 0.20],
                     [0.50, 0.10]])
target = torch.Tensor([0.7, 1.25, 1.4, 0.2, 0.4])
N = pred.shape[0]

mu = pred[:, 0]
var = F.softplus(pred[:, 1]) + 1e-6  # ensure positivity
logvar = torch.log(var)

# Mean loss
0.5 * torch.log(torch.Tensor([math.tau])) + 0.5 * torch.mean(logvar) \
    + torch.mean(torch.square(target - mu) / (2 * var))
# tensor([0.7930])

# Summed loss
0.5 * N * torch.log(torch.Tensor([math.tau])) + 0.5 * torch.sum(logvar) \
    + torch.sum(torch.square(target - mu) / (2 * var))
# tensor([3.9651])
```
Therefore, the problem potentially lies elsewhere. From various online sources, it seems that the Gaussian negative log-likelihood is allowed to be negative (i.e. there is no lower bound at 0).
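A quick check confirms this: with a small predicted variance and an accurate mean, the log-variance term dominates the same formula used above and the loss drops below zero (the numbers here are made up for illustration):

```python
import math

import torch

# An accurate, confident prediction: mean equals the target, small variance.
mu = torch.tensor([0.7])
var = torch.tensor([0.01])
target = torch.tensor([0.7])

# Same per-sample Gaussian NLL as above: the 0.5*log(var) term is very
# negative, so the total loss is negative — which is perfectly valid.
nll = 0.5 * math.log(math.tau) + 0.5 * torch.log(var) \
    + torch.square(target - mu) / (2 * var)
print(nll)  # tensor([-1.3836])
```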
After testing the algorithm on the dataset provided in the original paper, it seems that the algorithm performs quite well: it is able to find "interesting" points to sample, such that it finds a better variant than the ones provided. It does this quite efficiently as well, without running the algorithm on the whole dataset. Below, we summarize why the algorithm works well on the original dataset but does quite poorly on our (much smaller) one.
This would allow us to sample points in the latent space (z) more effectively, and to predict `yt` and `yt_gt` more accurately.

Trying the dense oracle on the relatively small dataset, the results are still sub-optimal. This is because the oracle predicts very small values for the sequences it is given (i.e. the variant `AWGV` gets a score that is lower than the true GT by quite a large margin). This is likely because, in the original dataset, many sequences are scored with a fitness score of 0 (or very near 0). As a result, the weights of the oracle model are quite poor. The figure below shows a paired plot between the oracle and the true GT fitness scores.
Each point corresponds to a variant of the 3GB1 sequence from the training dataset. The horizontal axis reports the ground-truth fitness value of the sequence, and the vertical axis reports the oracle's mean prediction for the sequence. Note that, even within the training distribution, the oracle is unable to differentiate what makes a good sequence (one with a high fitness score) from a bad one: two sequences with different GT fitness values are scored the same by the oracle. As stated above, this might be attributed to the fact that many sequences in the training dataset have a fitness value of 0.
One way to remedy this might be to give higher sample weights to the under-represented sequences when training the oracle itself. Alternatively, it might be worth getting rid of all samples whose fitness is below a certain threshold (e.g. the mean).
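Both remedies are easy to sketch with PyTorch's built-in utilities (the fitness values and the 4x weighting factor below are made up; `WeightedRandomSampler` is one possible way to implement the reweighting, not necessarily what the repo would use):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Made-up fitness scores: mostly (near) zero, a few high — mimicking GB1.
fitness = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.9, 1.2])

# Remedy 1: upweight the rare high-fitness samples when drawing minibatches.
weights = torch.ones_like(fitness)
weights[fitness > fitness.mean()] = 4.0  # arbitrary illustrative factor
sampler = WeightedRandomSampler(weights, num_samples=len(fitness),
                                replacement=True)

# Remedy 2: drop all samples below a threshold (here, the mean).
keep = fitness >= fitness.mean()
subset = fitness[keep]
print(len(list(sampler)), subset.numel())  # 6 2
```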
Addressing point #2 above: it seems that, due to the nature of the small initial dataset, subsetting the data by training only on the high fitness values does not actually improve the oracle predictions when compared against the true values. In fact, the predictions become worse.
The problem is, if the oracle is always going to predict the same value (with some small variance), then the same fitness score will be predicted for different variants/sequences. Therefore, even if we sample new variants (from the generative model), the oracle will always predict the same score, and thus never really learn where to sample next.
If we run the CbAS algorithm on the dataset described by the 2D matrix (between the oracle and the GT) and observe that similar sequences appear within the top k during each of the different runs, we can (likely) conclude that the algorithm was able to find the right location/space even though it does not know the "true" value of the sequences. This would still hold true even if the fitness scores of the returned sequences are not better/higher than the ones provided.
This might be the only way to show that the algorithm works on such a small dataset. To be honest, every other hope is lost; I've tried literally all of the above.
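The consistency check described above can be sketched as follows (the run results are made-up placeholders; in practice each dict would hold the scored candidates from an independent CbAS run):

```python
def topk_overlap(runs, k):
    """Jaccard overlap between the top-k sequence sets of each pair of runs."""
    tops = [set(sorted(run, key=run.get, reverse=True)[:k]) for run in runs]
    pairs = [(a, b) for i, a in enumerate(tops) for b in tops[i + 1:]]
    return [len(a & b) / len(a | b) for a, b in pairs]

# Placeholder results from two hypothetical independent runs.
run1 = {"AWGV": 0.9, "AWGA": 0.8, "TWGV": 0.2}
run2 = {"AWGV": 0.85, "AWGA": 0.7, "GWGV": 0.1}
print(topk_overlap([run1, run2], k=2))  # [1.0] — identical top-2 sets
```

A high overlap across runs supports the claim that the algorithm keeps finding the same region of sequence space, independent of the exact scores.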
The sequences that are predicted to be highest are already contained within the original dataset. Rather, we should remove the sequences from the original dataset that are in the top k.
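A sketch of that holdout (the dataset below is a made-up placeholder): strip the top-k variants, ranked by ground-truth fitness, from the starting pool so the algorithm has to rediscover that region rather than simply return sequences it was given.

```python
def remove_topk(dataset, k):
    """Drop the k highest-fitness sequences from a {sequence: fitness} dict."""
    topk = set(sorted(dataset, key=dataset.get, reverse=True)[:k])
    return {seq: y for seq, y in dataset.items() if seq not in topk}

# Placeholder GB1-style variants and fitness scores.
data = {"AWGV": 8.8, "AWGA": 2.1, "TWGV": 0.3, "GWGV": 0.0}
pruned = remove_topk(data, k=1)
print(sorted(pruned))  # ['AWGA', 'GWGV', 'TWGV']
```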
Ok, it seems like all the problems with CbAS are fixed. Further cleanup will be handled later.
What does this PR do?
Implements the Conditioning by Adaptive Sampling (CbAS) procedure on the GB1 protein fitness dataset.