EchteRobert opened 2 years ago
The percent replicating (PR) baseline to beat can be calculated in a variety of ways. Here, replicates are defined as perturbations with the same compound. Methods that have been tried to calculate the PR metric:
Note that the train/val split methods are used to create a fair comparison between data the model has and has not seen. This introduces some problems, so in the future a complete hold-out dataset will be used for fair comparison. The second method is not shown here, as it is not very compatible with calculating the PR for various reasons.
Negative controls (negcons) are removed from the dataset both during training and during evaluation with PR.
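The thread does not spell out the exact PR computation; a common formulation (as in the Cell Painting benchmark papers) counts a compound as replicating when the median pairwise correlation of its replicate profiles exceeds the 95th percentile of a null distribution of non-replicate correlations. A minimal numpy sketch under that assumption (all names hypothetical):

```python
import numpy as np

def percent_replicating(profiles, compounds, n_null=1000, pctl=95, seed=0):
    """Fraction of compounds whose median replicate correlation beats the
    `pctl`-th percentile of a null of non-replicate correlations.

    profiles:  (n_wells, n_features) aggregated well profiles
    compounds: (n_wells,) compound label per well
    """
    rng = np.random.default_rng(seed)
    corr = np.corrcoef(profiles)          # pairwise well-profile correlations
    compounds = np.asarray(compounds)
    n = len(compounds)

    # Null distribution: correlations between randomly paired non-replicate wells
    null = []
    while len(null) < n_null:
        i, j = rng.integers(0, n, size=2)
        if i != j and compounds[i] != compounds[j]:
            null.append(corr[i, j])
    threshold = np.percentile(null, pctl)

    # Signal: median pairwise correlation within each compound's replicates
    replicating = []
    for c in np.unique(compounds):
        idx = np.where(compounds == c)[0]
        if len(idx) < 2:
            continue
        pairs = [corr[a, b] for k, a in enumerate(idx) for b in idx[k + 1:]]
        replicating.append(np.median(pairs) > threshold)
    return float(np.mean(replicating))
```

With `nr_replicates = 3` below, each compound contributes three wells to the within-compound correlations.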
PR is calculated here with `nr_replicates = 3`. From top to bottom:
PR is calculated here with `nr_replicates = 6`:
The model architecture consists of 3 simple layers followed by a max pooling operation, which transforms the single-cell features into an aggregated profile in a different feature space. A projection head on top of this feature space projects the aggregated features into yet another space, and this final representation is used to calculate the loss. The first model shown to overfit the data is trained on the 2 plate train / 2 plate val split of the data, using the SupCon loss.
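As a rough sketch of that forward pass (layer sizes and the plain-numpy stand-ins are hypothetical; the real model is a trained network):

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda x: np.maximum(x, 0.0)

def linear(d_in, d_out):
    # Random weights stand in for trained layers (hypothetical init)
    return rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, d_out))

n_cells, d_feat, d_hidden, d_proj = 400, 1746, 256, 128  # assumed sizes
W1, W2, W3 = linear(d_feat, d_hidden), linear(d_hidden, d_hidden), linear(d_hidden, d_hidden)
Wp1, Wp2 = linear(d_hidden, d_hidden), linear(d_hidden, d_proj)

def forward(cells):
    """cells: (n_cells, d_feat) single-cell features for one well."""
    h = relu(relu(relu(cells @ W1) @ W2) @ W3)   # 3 simple layers, applied per cell
    profile = h.max(axis=0)                      # max pool -> aggregated profile
    z = relu(profile @ Wp1) @ Wp2                # projection head
    return z / np.linalg.norm(z)                 # unit norm, as used by SupCon

z = forward(rng.normal(size=(n_cells, d_feat)))  # one well -> one embedding
```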
From this we can deduce that the loss function (inversely) correlates nicely with the PR.
Loss curves
The corresponding PR scores
In order to get the model to generalize to the feature aggregation problem, instead of memorizing the training set, I added more data to the training. I did this in two ways:
`nr_cells`, the number of cells I sampled from each well, was set to 400. There are wells that contain fewer and more cells than that (see histogram below). I now changed this to sample multiple times from wells that contain more than 400 cells. More specifically, from each well I sampled: once if it contains fewer than `nr_cells` cells, twice if it contains more than `nr_cells` but fewer than 2×`nr_cells`, or thrice if it contains more than 2×`nr_cells`. This led to the first generalized model, which, after some tuning, showed the loss curves below.
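The sampling rule above can be sketched as follows (function names, the boundary handling at exactly `nr_cells`, and the with-replacement fallback for small wells are my assumptions, not from the thread):

```python
import random

def n_draws(n_cells_in_well, nr_cells=400):
    """How many samples of nr_cells cells to draw from one well."""
    if n_cells_in_well < nr_cells:
        return 1          # once if fewer than nr_cells
    if n_cells_in_well < 2 * nr_cells:
        return 2          # twice if between nr_cells and 2*nr_cells
    return 3              # thrice if more than 2*nr_cells

def sample_well(cell_ids, nr_cells=400, seed=0):
    """Return one or more samples of nr_cells cell ids from a well."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n_draws(len(cell_ids), nr_cells)):
        if len(cell_ids) >= nr_cells:
            draws.append(rng.sample(cell_ids, nr_cells))      # without replacement
        else:
            draws.append(rng.choices(cell_ids, k=nr_cells))   # small well: upsample
    return draws
```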
I calculated the PR both for the 2 plate train/val split and for the entire dataset (all plates); only the PR over all plates is shown here. I will use the model labels shown in the legend of the loss curves.
Sampling `nr_cells` (e.g. 400) cells is not representative enough in some cases. Important note: the 'semiGeneralized' model did not have its data sorted before splitting it into train/val. That means the model has seen all compound types during training, but has not seen them from every plate:
SemiGeneralized
Generalized
Generalized and Tuned
To test whether this unexpected correlation between the loss and the PR is related to the way the model is trained, especially how the batches are organized, I trained the same "Generalized and Tuned" model with batch sizes of 32 and 64 instead of the 16 used above.
I will investigate how to stabilize the influence of the batch size now.
Loss curves of the two models
PR of the model with BS32
PR of the model with BS64
Note to self: I am currently still using the representations in the SupCon loss space for the downstream analysis (PR). I tried using the representations created before the projection head, but that resulted in worse performance. My hypothesis is that this is because the projection head and the encoder part of the model are of similar size. I expect that the encoder needs to be several times larger than the projection head for it to learn a better representation. For example, the SupCon loss paper uses a feature embedding of size 2048, while the representations are reduced to 128 dimensions in the SupCon loss space.
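For reference, a minimal numpy version of the SupCon loss used above (following the formulation in Khosla et al.; embeddings are L2-normalized and positives are samples sharing a compound label — a sketch, not the training implementation):

```python
import numpy as np

def supcon_loss(z, labels, temperature=0.1):
    """Supervised contrastive loss over embeddings z: (N, d)."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # project to unit sphere
    labels = np.asarray(labels)
    N = len(labels)
    sim = z @ z.T / temperature                        # scaled cosine similarities
    eye = np.eye(N, dtype=bool)

    # log-softmax over all samples except the anchor itself
    sim_masked = np.where(eye, -np.inf, sim)
    log_prob = sim - np.log(np.exp(sim_masked).sum(axis=1, keepdims=True))

    pos = (labels[:, None] == labels[None, :]) & ~eye  # positives: same label
    has_pos = pos.sum(axis=1) > 0                      # anchors with >=1 positive
    mean_log_prob = (log_prob * pos).sum(axis=1)[has_pos] / pos.sum(axis=1)[has_pos]
    return float(-mean_log_prob.mean())
```

Tight same-compound clusters give a low loss; positives scattered among negatives give a high one, which is the behavior the loss/PR comparison above relies on.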
This first line of experiments is designed to beat the PR baseline of profiles created with the current aggregation method (which takes the mean over single cells). All information about the compound plates used can be found here: https://github.com/jump-cellpainting/2021_Chandrasekaran_submitted.