carpenter-singh-lab / 2024_vanDijk_CytoSummaryNet


01. First model for CPJUMP1 compound plates #1

Open EchteRobert opened 2 years ago

EchteRobert commented 2 years ago

This first line of experiments is designed to beat the PR baseline of profiles created with the current aggregation method, which takes the mean. All information about the compound plates used can be found here: https://github.com/jump-cellpainting/2021_Chandrasekaran_submitted.

EchteRobert commented 2 years ago

Calculating the baseline

The baseline to beat, in terms of percent replicating (PR), can be calculated in several ways. Replicates are wells treated with the same compound. Methods tried so far for calculating PR:

  1. using the 2 plate train/2 plate val split from the earliest model iterations
  2. using the 80/20 train/val split that the model is trained with
  3. using all 4 plates (the default)

Note that the train/val split methods are used to make a fair comparison between data the model has and has not seen. This introduces some problems, so in the future a completely held-out dataset will be used for fair comparison. The second method is not shown here, because it is not very compatible with calculating PR, for several reasons.

Negative controls (negcons) are removed from the dataset during both training and PR evaluation.
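For reference, a minimal NumPy sketch of how PR can be computed, following the CPJUMP1 definition as I understand it: a compound counts as replicating when the median pairwise Pearson correlation between its replicate wells exceeds the 95th percentile of a null distribution built from random groups of wells that all have different compounds. The function name, the `n_null` and `seed` parameters, and the use of `n_replicates` (standing in for nr_replicates) as the null group size are assumptions for illustration; the profiles are assumed to have negcons removed already.

```python
import numpy as np


def percent_replicating(profiles, labels, n_replicates, n_null=1000, seed=0):
    """Percent of compounds whose replicate correlation beats the 95th percentile of a null.

    profiles: (n_wells, n_features) aggregated well profiles, negcons already removed.
    labels: compound label per well.
    """
    labels = np.asarray(labels)
    if len(np.unique(labels)) < n_replicates:
        raise ValueError("need at least n_replicates distinct compounds for the null")
    rng = np.random.default_rng(seed)
    corr = np.corrcoef(profiles)  # Pearson correlation between every pair of wells

    def median_pairwise(idx):
        # Median of the correlations between all distinct pairs of wells in idx
        sub = corr[np.ix_(idx, idx)]
        return np.median(sub[np.triu_indices(len(idx), k=1)])

    # Replicate correlation: median pairwise correlation between wells of one compound
    compounds = [c for c in np.unique(labels) if np.sum(labels == c) > 1]
    replicate_corr = np.array(
        [median_pairwise(np.flatnonzero(labels == c)) for c in compounds]
    )

    # Null: the same statistic for random groups of wells that all have different compounds
    null = []
    while len(null) < n_null:
        idx = rng.choice(len(labels), size=n_replicates, replace=False)
        if len(np.unique(labels[idx])) == n_replicates:
            null.append(median_pairwise(idx))
    threshold = np.percentile(null, 95)
    return 100.0 * np.mean(replicate_corr > threshold)
```

With the 2 plate split this would be called with `n_replicates=3`, and with all plates with `n_replicates=6`.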

2 plate training (BR00117010, BR00117011) and 2 plate validation (BR00117012, BR00117013)

PR is calculated here with nr_replicates = 3. From top to bottom:

  1. Training PR - mean aggregated profiles, no feature selection
  2. Validation PR - mean aggregated profiles, no feature selection

BENCHMARK_TVsplit_noNegcon_pertIname_nR3_PR

Default PR (as reported on the CPJUMP1 GitHub)

PR is calculated here with nr_replicates = 6:

  1. All wells PR - mean aggregated profiles, no feature selection

pertIname_noNegcon_nR6_PR

EchteRobert commented 2 years ago

Experiment:

The model consists of 3 simple layers followed by a max pooling operation over the cells, which transforms the single-cell features into an aggregated profile in a new feature space. A projection head on top of this space projects the aggregated features into another space, and this final representation is used to calculate the loss, which is the supervised contrastive (SupCon) loss. The first model, which is shown to overfit the data, is trained on the 2 plate train/2 plate val split.
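A minimal NumPy sketch of this forward pass, for illustration only: the ReLU activations, He initialisation, layer sizes (`hidden`, `n_embed`, `n_proj`) and the two-layer projection head are all assumptions, not details of the actual model. Max pooling over the cell axis makes the well profile independent of cell order, and the projection is L2-normalised, as SupCon expects.

```python
import numpy as np


class CellAggregatorSketch:
    """Hypothetical model: 3 per-cell layers -> max pool over cells -> projection head."""

    def __init__(self, n_in, hidden=256, n_embed=128, n_proj=64, seed=0):
        rng = np.random.default_rng(seed)

        def dense(i, o):
            # He-initialised weights and zero biases (assumed initialisation)
            return rng.normal(scale=np.sqrt(2.0 / i), size=(i, o)), np.zeros(o)

        self.encoder = [dense(n_in, hidden), dense(hidden, hidden), dense(hidden, n_embed)]
        self.head = [dense(n_embed, n_embed), dense(n_embed, n_proj)]

    @staticmethod
    def _mlp(x, layers):
        # ReLU after every layer except the last one
        for k, (w, b) in enumerate(layers):
            x = x @ w + b
            if k < len(layers) - 1:
                x = np.maximum(x, 0.0)
        return x

    def forward(self, cells):
        """cells: (n_cells, n_in) single-cell features from one well."""
        per_cell = self._mlp(cells, self.encoder)  # (n_cells, n_embed)
        profile = per_cell.max(axis=0)             # max pool over cells -> (n_embed,)
        z = self._mlp(profile, self.head)          # (n_proj,)
        return profile, z / np.linalg.norm(z)      # unit-norm projection for SupCon
```

Here `profile` is the representation before the projection head and `z` the one in the SupCon loss space; the note at the end of this thread compares the two for PR.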

Main Takeaway:

From this it appears that the loss correlates nicely (inversely) with PR: a lower loss goes together with a higher PR.

Loss curves

TrainValLoss_overfit

The corresponding PR scores: TVsplit_MLP_noNegcon_pertIname_nR3_PR
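For reference, a minimal NumPy sketch of the SupCon loss (the L^sup_out form from Khosla et al., 2020) for one batch, where samples sharing a compound label are treated as positives. The function name and the temperature of 0.1 are assumptions, not taken from the actual training code. A lower loss means same-compound projections are pulled together and the rest pushed apart, which is why a lower loss would be expected to go with a higher PR.

```python
import numpy as np


def supcon_loss(z, labels, temperature=0.1):
    """Supervised contrastive loss for one batch.

    z: (n, d) projections; labels: length-n compound labels (same label = positive pair).
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    labels = np.asarray(labels)
    self_mask = np.eye(len(labels), dtype=bool)
    # Scaled cosine similarities; an anchor is never compared with itself
    sim = np.where(self_mask, -np.inf, z @ z.T / temperature)
    # Row-wise log-softmax over all other samples (numerically stable log-sum-exp)
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - (row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True)))
    pos = (labels[:, None] == labels[None, :]) & ~self_mask
    n_pos = pos.sum(axis=1)
    has_pos = n_pos > 0  # anchors without a positive in the batch are skipped
    mean_log_prob_pos = np.where(pos, log_prob, 0.0).sum(axis=1)[has_pos] / n_pos[has_pos]
    return float(-mean_log_prob_pos.mean())
```

For a batch where the projections already separate the compounds perfectly, the loss approaches log(number of positives per anchor) rather than zero, because each positive competes with the other positives in the softmax.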

EchteRobert commented 2 years ago

Experiment

To make the model generalize to the feature aggregation problem, instead of memorizing the training set, I added more data to training. I did this in two ways:

  1. I aggregated all well data, sorted it by compound name, and then split it into 80% training and 20% validation data. This means the validation set contains compounds the training set has not seen before.
  2. Until now, nr_cells, the number of cells sampled from each well, was set to 400. Some wells contain fewer cells and some contain more (see the histogram below). I changed this to sample multiple times from wells that contain more than 400 cells. Specifically, each well is sampled once if it has fewer than nr_cells cells, twice if it has between nr_cells and 2 * nr_cells cells, and three times if it has more than 2 * nr_cells cells.
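A sketch of both data changes, under assumptions the text leaves open: wells with exactly nr_cells or 2 * nr_cells cells fall in the lower bin, wells with fewer than nr_cells cells are sampled with replacement, and the 80/20 split is taken over the sorted unique compound names so that no compound ends up in both sets. The helper names are hypothetical.

```python
import numpy as np


def n_subsamples(n_cells, nr_cells=400):
    """Number of nr_cells-sized subsamples drawn from a well with n_cells cells."""
    if n_cells <= nr_cells:
        return 1
    if n_cells <= 2 * nr_cells:
        return 2
    return 3


def subsample_well(cells, nr_cells=400, rng=None):
    """Draw independent random subsets of nr_cells cells from one well (they may overlap)."""
    if rng is None:
        rng = np.random.default_rng(0)
    n = len(cells)
    # Sampling with replacement only when the well is too small (an assumption)
    return [cells[rng.choice(n, size=nr_cells, replace=n < nr_cells)]
            for _ in range(n_subsamples(n, nr_cells))]


def compound_split(compounds, train_frac=0.8):
    """Split well indices on sorted unique compound names (hypothetical helper)."""
    names = sorted(set(compounds))
    train_names = set(names[: int(round(train_frac * len(names)))])
    train_idx = [i for i, c in enumerate(compounds) if c in train_names]
    val_idx = [i for i, c in enumerate(compounds) if c not in train_names]
    return train_idx, val_idx
```

Splitting on unique names rather than on sorted rows avoids a compound straddling the 80% cut, which would otherwise leak some of its wells into the validation set.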

This led to the first generalized model, which, after some tuning, produced the loss curves shown below.

I calculated the PR both for the 2 plate train/val split and for the entire dataset (all plates), but only show the PR of all plates here. I refer to the models by the labels in the legend of the loss curves.

Main takeaways

  1. Using the methods described above results in lower training and validation losses, as expected.
  2. Seeing all compound types during training ('semiGeneralized' model) substantially increases the PR of a model's profiles, even though its validation loss can be similar to that of other models that achieve a lower PR. Although expected, it is good to note that this can happen.
  3. As the model is tuned to achieve lower training and validation losses, the PR seems to decrease. This is unexpected, as it contradicts what was found in the previous experiment. The only difference I can think of is that the model is trained on positive pairs that mostly come from the same well (but different cells), while this PR comparison only looks at positive pairs from different wells (but the same compound). The difference in performance may then be caused by the model not learning the representations as well as hoped, OR by the random subsample of nr_cells (e.g. 400) cells not being representative enough in some cases.
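To make the distinction in takeaway 3 concrete, a hypothetical helper that builds both positive-pair masks over a batch of subsamples: during training, any two subsamples with the same compound are positives, including two subsamples drawn from the same well, whereas PR only compares different wells. This assumes a batch can contain several subsamples per well, as the sampling scheme above allows.

```python
import numpy as np


def positive_pair_masks(compounds, wells):
    """Return (train_pos, pr_pos): boolean (n, n) masks over the subsamples in a batch."""
    compounds = np.asarray(compounds)
    wells = np.asarray(wells)
    not_self = ~np.eye(len(compounds), dtype=bool)
    same_compound = compounds[:, None] == compounds[None, :]
    same_well = wells[:, None] == wells[None, :]
    train_pos = same_compound & not_self  # same-well pairs of different subsamples count
    pr_pos = same_compound & ~same_well   # PR only pairs wells that differ
    return train_pos, pr_pos
```

For example, with compounds `["A"] * 3 + ["B"] * 3` and wells `[1, 1, 2, 3, 3, 4]`, `train_pos` contains 12 ordered positive pairs and `pr_pos` only 8; the other 4 are same-well pairs that PR never evaluates.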

NrCellsHist

Important note: the 'semiGeneralized' model did not have its data sorted before it was split into train/val. This means the model has seen all compound types during training, but not from every plate:

TrainValLoss_Generalized

SemiGeneralized: AllPlates_semiGeneralized_PR

Generalized: AllPlates_Generalized_PR

Generalized and Tuned: AllPlates_generalizedTuned_PR

EchteRobert commented 2 years ago

Experiment

To test whether this unexpected relationship between the loss and PR depends on how the model is trained, especially on how the batches are organized, I trained the same "Generalized and Tuned" model with batch sizes of 32 and 64 instead of the 16 used above.

Main takeaways

The batch size influences the results; next, I will investigate how to stabilize this influence.

Loss curves of the two models: TrainValLoss_all

PR of the model with BS32: AllPlates_generalizedTuned_BS32_PR

PR of the model with BS64: AllPlates_generalizedTuned_BS64_PR

EchteRobert commented 2 years ago

Note to self: I am currently still using the representations in the SupCon loss space for downstream analysis (PR). I tried using the representations created before the projection head, but that resulted in worse performance. My hypothesis is that this happens because the projection head and the encoder part of the model are of similar size. I expect that the encoder needs to be several times larger than the projection head to learn a better representation. For example, the SupCon paper uses a feature embedding of size 2048, which the projection head reduces to 128 dimensions in the SupCon loss space.