EchteRobert opened this issue 2 years ago
This model is trained on 3 plates from Stain2 and Stain4 at the same time and evaluated on Stain2, Stain3, and Stain4.
Training on Stain2 and Stain4 yields similar results to the previous model: it still generalizes to Stain3. However, one of the plates outside of the Stain3 cluster (BR00115126) did not perform as well, showing that there are still some plate effects that are being learned.
This model is trained on 3 plates from Stain2 and Stain3 at the same time and evaluated on Stain2, Stain3, and Stain4.
This model is trained on 2 plates from Stain2, Stain3, and Stain4 and evaluated on all the remaining plates within their clusters.
This model is trained on 3 plates from Stain2, Stain3, and Stain4 and evaluated on all the remaining plates within their clusters.
For a complete discussion of all trained models, see the comment below.
Here I compare all trained models described in the previous comments.
Model | Average rank across metrics
---|---
S3+S4 | 2.92
S2+S4 | 2.83
S2+S3 | 3.67
S2+S3+S4 (6 plates) | 3.58
S2+S3+S4 (9 plates) | 1.92
Individual cluster | 4.92
As a final test, to see if increasing the number of training plates increases performance on validation compounds and plates, I train a model with 4 plates from Stain2, Stain3, and Stain4.
Adding this model to the rank analysis from the previous comment, we see that increasing the number of plates does indeed increase the average validation mAP. Note, however, that there is a bias here: as more plates serve as training data, those plates' validation mAP values are also included in these calculations. The model is even starting to generalize to the outlier plates in Stain4.
Model | Average mean rank | Average median rank | Average min rank | Average max rank | Average
---|---|---|---|---|---
S3+S4 | 3.67 | 3.67 | 2.67 | 4.00 | 3.50
S2+S4 | 3.67 | 4.33 | 4.33 | 2.67 | 3.75
S2+S3 | 4.67 | 4.33 | 5.33 | 4.00 | 4.58
S2+S3+S4 (6 plates) | 5.33 | 4.33 | 3.67 | 4.67 | 4.50
S2+S3+S4 (9 plates) | 3.33 | 2.00 | 2.00 | 3.00 | 2.58
Individual cluster | 6.00 | 5.67 | 6.00 | 6.00 | 5.92
S2+S3+S4 (12 plates) | 1.33 | 1.33 | 2.00 | 2.00 | 1.67
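For reference, the average-rank aggregation used in these tables can be sketched as follows. The mAP values below are illustrative placeholders, not the real per-metric numbers: for each metric, rank the models (1 = best) and average the ranks per model.

```python
import numpy as np

models = ["S3+S4", "S2+S4", "S2+S3+S4 (12 plates)"]
# Rows: models, columns: metrics (e.g. mean/median/min validation mAP).
# Values are illustrative placeholders only.
scores = np.array([
    [0.35, 0.34, 0.27],
    [0.34, 0.35, 0.26],
    [0.40, 0.39, 0.31],
])

# Rank models within each metric (rank 1 = highest mAP), then average.
order = np.argsort(-scores, axis=0)
ranks = np.empty_like(order)
for j in range(scores.shape[1]):
    ranks[order[:, j], j] = np.arange(1, scores.shape[0] + 1)

avg_rank = ranks.mean(axis=1)
for name, r in zip(models, avg_rank):
    print(f"{name}: {r:.2f}")
```

A lower average rank means the model is more consistently the best across metrics, which is how the 12-plate model comes out on top above.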
To test the influence of which training plates are used on model generalization, I switched up all the training plates and added 3 outlier plates (according to the PC1 loading correlations) as well. I then trained the model in the same way as previous models. Note that comparing the performance of the models is now even harder as the validation plates are completely different.
It appears that, as long as enough training plates are used (i.e. at least 12 here), the model is able to learn a general method of aggregation for different types of analysis pipelines, regardless of which training plates are used. That said, I do think that using plates from different Stains (which differ considerably in terms of feature importances) is beneficial for generalization.
Now that the model is getting consistent results on Stain2, Stain3, and Stain4, I want to do some qualitative analyses to investigate what the model is learning and what it outputs. First up are UMAPs of the model-aggregated well profiles of the validation compounds for Stain2, Stain3, and Stain4.
The model ignores batch effects for strong-signal compounds and clusters them nicely. Mean aggregation also clusters strong-signal compounds reasonably well while ignoring plate effects, but its clusters are much less separated than the model's.
In continuation of the previous experiment, I visualized the saliency of cells (i.e. the summed gradients over all features with respect to the SupConLoss over all wells). With this visualization I attempt to show how the model selects certain cells over others. I visualize 3 compounds that are poorly profiled by the mean (~0.3 mAP) but strongly profiled by the model (~0.9 mAP): sirolimus (red), skepinone-l (green), and purmorphamine (viridis). For each compound I take two wells to visualize.
Perhaps tracing back these cells to the images will give us more insight into what the model is learning.
Here, I show the raw images of a purmorphamine well (M08) in plate BR00112197binned (Stain2). Stain2 only contains 4 images, so the FOV is larger than in the other Stain datasets. I use green and red boxes to denote high (>0.8) and low (<0.2) saliency cells. Perhaps in the future I will find a better way of visualizing these cells, as the overlay impedes visual analysis. I am only showing one FOV here, split into 4 sections for inspection purposes.
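A minimal sketch of the box overlay described above, assuming normalized saliency scores in [0, 1] and known cell centers (both hypothetical inputs here; the real pipeline derives them from the segmentation):

```python
import numpy as np

def draw_saliency_boxes(image, centers, saliencies, half=10,
                        high_thresh=0.8, low_thresh=0.2):
    """Outline high-saliency cells (>0.8) in green and low (<0.2) in red.

    `image` is an (H, W, 3) RGB array; `centers` are (row, col) cell centers.
    Hypothetical helper sketching the overlay, not the actual repo code.
    """
    out = image.copy()
    for (r, c), s in zip(centers, saliencies):
        if s > high_thresh:
            color = (0, 255, 0)   # green box: high saliency
        elif s < low_thresh:
            color = (255, 0, 0)   # red box: low saliency
        else:
            continue              # mid-saliency cells are left unmarked
        r0, r1 = max(r - half, 0), min(r + half, out.shape[0] - 1)
        c0, c1 = max(c - half, 0), min(c + half, out.shape[1] - 1)
        out[r0, c0:c1 + 1] = color  # top edge
        out[r1, c0:c1 + 1] = color  # bottom edge
        out[r0:r1 + 1, c0] = color  # left edge
        out[r0:r1 + 1, c1] = color  # right edge
    return out

img = np.zeros((100, 100, 3), dtype=np.uint8)
boxed = draw_saliency_boxes(img, [(50, 50)], [0.9])
print(boxed[40, 50].tolist())  # top edge of the green box: [0, 255, 0]
```

Drawing only the box edges (rather than filling the cell) keeps the underlying pixels visible, which is the point of switching away from whole-cell coloring.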
The following was outlined by Mehrtash:
To gain more insight into what the model is doing, it would be very useful to "color" several complete FOVs based on the saliency scores and visually inspect them (to begin with). In the least interesting scenario, I suspect that the model might have learned to become a really good QC filter + mean aggregation over the passing cells -- which is still quite interesting, remarkable, and explains why it generalizes to new compounds. Another possibility is that the model might have further learned to pick divergent morphologies (in relevant directions) from the given bunch, come up with a consensus over those, and output the consensus features.
It seems like the model mostly attends to cells that are clearly separated, while giving less attention to cells in very crowded regions. This can be seen in all four FOVs shown below. These images are taken from only one well and one compound, though, so I will need to check other wells and plates to see whether this trend persists.
The following experiment was outlined by Mehrtash:
Here's a useful experiment to gain more insight about what the network is doing: take a large number of cells from the same compound (and across several plates) and classify them according to saliency score into two groups -- high: top 20% in saliency, and low: bottom 20% in saliency; throw the middle away. Now, make synthetic inputs to your network with different admixtures of high and low saliency cells, e.g. 0 high + 500 low, 1 high + 499 low, 2 high + 498 low, ..., 499 high + 1 low, 500 high + 0 low, in a deterministic way (e.g. add one high, remove one low, rinse and repeat). Take a PCA of the network output over these 500 inputs and plot the first few PCs vs. admixing fraction, with 0 meaning 0 high + 500 low, and 1 meaning 500 high + 0 low. If you see a "gating" behavior w.r.t. admixing fraction, i.e. the PCs jumping up sharply after a threshold of high saliency cells and quickly stabilizing, then the network has definitely learned to ignore low saliency cells. The noise of the output further sheds light on what the network is doing to the high saliency cells: if the network is simply averaging high saliency cells, you'd expect ~1/sqrt(N) noise in the network output, where N is the number of high saliency cells in the input. If the network is doing feature learning and gating, you'd see faster scaling, e.g. 1/N or faster.
I performed this experiment for multiple saliency cut-offs (5, 10, 20, and 40%) and tried different numbers of cells for the admixtures. I eventually settled on using 1000 cells (instead of the 500 mentioned above). Using more cells simply increases the 'resolution' of the figures by creating more datapoints. Note that for this experiment I am using 4 wells from a single plate (instead of multiple). I calculate the X% most salient cells per well and then merge them in one big pool to sample from during the experiment.
I use three types of saliency: gradient-based, distance-based (in loss space), and hold-one-out, named V1, V2, and V3 respectively. V1 is considered noisier, and it does not necessarily point to the cells that are most or least representative of a certain profile; I think it rather points to cells whose features are most influential in creating an aggregated profile that is best positioned in the loss space. Its exact interpretation remains hard to pin down. V2 measures how far each single cell in a set is from the aggregated profile (computed using all cells in the set): cells further away are considered less salient and cells close by more salient. V3 computes the profile for a well while iteratively leaving one cell out of the set, until there are N profiles for a well with N cells. The supervised contrastive loss is then calculated for each of these profiles with respect to the aggregated profiles of all other wells on the plate (3 positive pairs and 380 negative pairs). Profiles with a higher loss are given a higher saliency, and vice versa.
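The V2 and V3 ideas can be sketched in a few lines of numpy, using plain mean aggregation as a stand-in for the model (the real V3 scores each leave-one-out profile with the SupConLoss against all other wells; the toy data here is random):

```python
import numpy as np

rng = np.random.default_rng(0)
cells = rng.normal(size=(50, 8))  # toy well: 50 cells x 8 features

# V2: distance-based saliency. Cells far from the aggregated profile
# (here: the plain mean) are less salient, so saliency is the negated distance.
profile = cells.mean(axis=0)
v2 = -np.linalg.norm(cells - profile, axis=1)

# V3: hold-one-out saliency. Recompute the profile with each cell left out;
# in the real setup each leave-one-out profile is scored with the SupConLoss
# against all other wells on the plate. As a stand-in, score each cell by
# how much its removal shifts the aggregated profile.
n = len(cells)
loo_profiles = (profile * n - cells) / (n - 1)       # mean without cell i
v3 = np.linalg.norm(loo_profiles - profile, axis=1)  # profile shift per cell

# Under mean aggregation the two orderings are exactly opposite: the cell
# furthest from the mean shifts it the most when removed.
print(np.argmax(v3) == np.argmin(v2))  # True
```

With a learned aggregator in place of the mean, the two orderings need not coincide, which is why the experiments below compare all three variants.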
As a sanity check, I also performed this experiment using a cut-off of 100%, i.e. randomly selecting cells. This experiment should show no change as a function of the admixing fraction, because little variance should be captured in the first few PCs (all profiles should be similar).
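The admixing procedure itself can be sketched as follows. Here mean aggregation stands in for the trained network and the "high"/"low" pools are synthetic shifted Gaussians, so PC1 varies linearly with the admixing fraction rather than gating; with the real network, a sharp jump in PC1 at some threshold fraction would indicate gating.

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-ins for the two saliency pools (shifted Gaussians).
high = rng.normal(loc=1.0, size=(1000, 16))
low = rng.normal(loc=-1.0, size=(1000, 16))

N = 1000
outputs = []
for k in range(N + 1):                # k high cells + (N - k) low cells
    mix = np.vstack([high[:k], low[:N - k]])
    outputs.append(mix.mean(axis=0))  # mean aggregation stands in for the model
outputs = np.asarray(outputs)

# PCA via SVD of the centered outputs: plot PC1 vs. the admixing fraction k/N.
centered = outputs - outputs.mean(axis=0)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
pc1 = centered @ vt[0]

# A plain mean varies linearly with the admixing fraction; a gating network
# would instead show a sharp jump at some threshold.
corr = np.corrcoef(pc1, np.arange(N + 1) / N)[0, 1]
print(abs(corr) > 0.99)  # True for mean aggregation
```

The 100% cut-off sanity check corresponds to drawing both pools from the same distribution, in which case PC1 carries no admixing signal at all.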
_All of the results below are calculated with 'run-20220505221947-1m1zas58', aka the 'Stain234 12 plates outliers' model._
I have updated the saliency-based cell image outlines; they now use square boxes instead of coloring the entire cell. I use either V0 (L1 norm of the first activation layer) or V1 (L1 norm of the gradient backpropagated from the SupConLoss) saliency for the image boxes. I calculated the Pearson correlation between the various saliencies and the CellProfiler features of the input cells; the main idea is to figure out what the saliencies indicate. From visual inspection of the full FOVs with the V1 saliency overlay, we can see that higher-saliency cells tend to be isolated while lower-saliency cells tend to lie on top of each other or sit in more crowded regions. If this is what the model is generally doing, the features corresponding to isolation should correlate strongly with the V1 saliency.
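The correlation analysis can be sketched like this, with toy stand-ins for the CellProfiler features (the feature names and distributions are illustrative, not the real schema, and the toy saliency is constructed to prefer isolated cells):

```python
import numpy as np

rng = np.random.default_rng(2)
n_cells = 500

# Toy stand-ins for CellProfiler features; names are illustrative only.
features = {
    "Cells_Neighbors_FirstClosestDistance": rng.gamma(2.0, 5.0, n_cells),
    "Cells_Intensity_IntegratedIntensity_DNA": rng.gamma(3.0, 2.0, n_cells),
    "Cells_AreaShape_Area": rng.gamma(4.0, 50.0, n_cells),
}

# Toy saliency that, by construction, prefers isolated cells.
saliency = (features["Cells_Neighbors_FirstClosestDistance"]
            + rng.normal(0, 1, n_cells))

# Pearson correlation of each feature with the saliency score,
# sorted by absolute correlation.
corrs = {name: np.corrcoef(vals, saliency)[0, 1]
         for name, vals in features.items()}
for name, r in sorted(corrs.items(), key=lambda kv: -abs(kv[1])):
    print(f"{name}: {r:+.2f}")
```

In the real analysis, a strong positive correlation between V1 saliency and the neighbor-distance features would support the isolation hypothesis.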
The model likely gives higher weight to cells that are more isolated, as indicated by the AreaShape, IntegratedIntensity (sum over intensity pixels), and nearest-neighbor-distance features. It also gives more weight to cells with low DNA, RNA, and Mito intensities. Together, these correlations suggest a quality-control filter: isolated cells are resolved more cleanly, while high DNA, RNA, and Mito intensities indicate cells that are in the process of dividing.
Below are the mean average precision values for matching sister compounds using the model, baseline, or random shuffling.
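For reference, a sketch of how mAP for matching same-label wells (replicates or sister compounds) can be computed with cosine-similarity retrieval. The data, labels, and retrieval metric here are illustrative assumptions, not the exact evaluation code:

```python
import numpy as np

def mean_average_precision(profiles, labels):
    """mAP for retrieving same-label wells by cosine similarity."""
    X = profiles / np.linalg.norm(profiles, axis=1, keepdims=True)
    sims = X @ X.T
    np.fill_diagonal(sims, -np.inf)        # a well never retrieves itself
    aps = []
    for i in range(len(labels)):
        order = np.argsort(-sims[i])[:-1]  # ranked list, self dropped (last)
        hits = (labels[order] == labels[i]).astype(float)
        precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
        aps.append((precision_at_k * hits).sum() / hits.sum())
    return float(np.mean(aps))

rng = np.random.default_rng(3)
labels = np.repeat(np.arange(10), 4)                   # 10 compounds x 4 wells
signal = rng.normal(size=(10, 32)) * 3                 # per-compound signature
profiles = signal[labels] + rng.normal(size=(40, 32))  # well = signature + noise
map_model = mean_average_precision(profiles, labels)

# Random-shuffling baseline: permuting the labels breaks the matching.
map_random = mean_average_precision(profiles, rng.permutation(labels))
print(map_model > map_random)  # True
```

The shuffled-label score gives the chance-level floor that the model and baseline numbers are compared against.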
Just as a last test, I evaluated the trained model (on 15 plates) on the generated ellipsoid data. If you need a refresher on the experimental setup: https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/3#issuecomment-1098357632 I am still using 2 dimensions to describe the ellipsoids, but I added 1322 empty dimensions to make the input fit the model. This should be a trivial experiment, as the model has already shown that it can beat the baseline and thus learn more than the mean. However, the theory is now that it is applying some form of quality control. If that means it is selecting cells which accurately describe the second moments of the cell-set distribution, then this task should always be completed perfectly. However, if it is also selecting cells whose profiles are close to the mean, it will not. It is also possible that the model is actually computing higher-order moments from the input data and creating a profile based on that information.
Because I am using only 2 dimensions, I will roll those 2 dimensions over the 1324 available positions to see if this influences the model's output, plotting the mAP as a function of the rolled position. Although not an exact measure, this indicates which features (according to their position) the model relies on more than others: low scores correspond to feature positions the model pays little attention to, and vice versa. Moreover, this means that the AreaShape, IntegratedIntensity, and Neighbors feature positions are unavailable in some cases.
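The rolling setup can be sketched as follows; the model call itself is omitted (the trained network is assumed), this only shows how the 2 informative ellipsoid dimensions are placed at each rolled position of the 1324-dim input:

```python
import numpy as np

# Hypothetical sketch: embed the 2 informative ellipsoid dimensions at a
# rolled position inside the 1324-dim model input (1322 zero paddings).
def roll_features(cells_2d, shift, n_features=1324):
    padded = np.zeros((cells_2d.shape[0], n_features))
    padded[:, :2] = cells_2d
    return np.roll(padded, shift, axis=1)

rng = np.random.default_rng(4)
cells = rng.normal(size=(100, 2))

x0 = roll_features(cells, 0)           # signal at positions 0-1
x5 = roll_features(cells, 5)           # signal at positions 5-6
print(x0.shape)                        # (100, 1324)
print(np.allclose(x5[:, 5:7], cells))  # True
# Each rolled version is fed to the trained model and the resulting mAP is
# recorded as a function of the shift (model call omitted here).
```

Because the signal occupies different positions at each shift, low mAP at a given shift suggests the model under-uses those feature positions.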
@shntnu @johnarevalo I wonder what your thoughts are on this. Does this make sense or did I miss something?
Two cluster training data (T: S3+S4)
Some final tweaks to training the model will be made in this issue. All of these tweaks will be made with Stain2, Stain3, and Stain4 in mind at the same time, instead of one at a time. The first model is trained on 3 plates from Stain3 and Stain4 at the same time and evaluated on Stain2, Stain3, and Stain4.
Main takeaways
Table Stain4
| plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
|:---|---:|---:|---:|---:|---:|---:|
| _Training plates_ | | | | | | |
| BR00116625highexp | **0.74** | 0.32 | **0.36** | 0.28 | 98.9 | 61.1 |
| BR00116628highexp | **0.73** | 0.32 | **0.32** | 0.31 | 98.9 | 57.8 |
| BR00116629highexp | **0.78** | 0.29 | **0.35** | 0.29 | 100 | 52.2 |
| _Validation plates_ | | | | | | |
| BR00116631highexp | **0.47** | 0.28 | 0.27 | **0.3** | 93.3 | 53.3 |
| BR00116625 | **0.6** | 0.31 | **0.35** | 0.29 | 98.9 | 58.9 |
| BR00116630highexp | **0.52** | 0.29 | **0.3** | 0.3 | 97.8 | 58.9 |
| BR00116631 | **0.5** | 0.3 | 0.26 | **0.28** | 94.4 | 57.8 |
| BR00116627highexp | **0.55** | 0.31 | **0.38** | 0.27 | 98.9 | 56.7 |
| BR00116627 | **0.55** | 0.3 | **0.36** | 0.29 | 96.7 | 56.7 |
| BR00116629 | **0.61** | 0.3 | **0.32** | 0.29 | 98.9 | 52.2 |
| BR00116628 | **0.58** | 0.32 | 0.28 | **0.29** | 98.9 | 58.9 |

Table Stain3
| plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
|:---|---:|---:|---:|---:|---:|---:|
| _Training plates_ | | | | | | |
| BR00115134 | **0.75** | 0.37 | **0.42** | 0.33 | 98.9 | 58.9 |
| BR00115125 | **0.75** | 0.36 | **0.44** | 0.29 | 98.9 | 54.4 |
| BR00115133highexp | **0.76** | 0.38 | **0.38** | 0.31 | 97.8 | 60 |
| _Validation plates_ | | | | | | |
| BR00115128highexp | **0.52** | 0.4 | **0.42** | 0.33 | 97.8 | 58.9 |
| BR00115125highexp | **0.58** | 0.37 | **0.41** | 0.31 | 98.9 | 55.6 |
| BR00115131 | **0.54** | 0.38 | **0.44** | 0.29 | 98.9 | 58.9 |
| BR00115126 | **0.34** | 0.32 | **0.33** | 0.28 | 57.8 | 53.3 |
| BR00115133 | **0.58** | 0.38 | **0.4** | 0.3 | 96.7 | 62.2 |
| BR00115127 | **0.56** | 0.38 | **0.47** | 0.31 | 98.9 | 58.9 |
| BR00115128 | **0.53** | 0.39 | **0.42** | 0.32 | 96.7 | 61.1 |
| BR00115129 | **0.57** | 0.38 | **0.45** | 0.32 | 98.9 | 52.2 |

Table Stain2
| plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
|:---|---:|---:|---:|---:|---:|---:|
| BR00112202 | **0.43** | 0.34 | **0.38** | 0.3 | 88.9 | 54.4 |
| BR00112197standard | **0.45** | 0.4 | **0.41** | 0.28 | 85.6 | 56.7 |
| BR00112198 | **0.43** | 0.35 | **0.4** | 0.3 | 91.1 | 56.7 |
| BR00112197repeat | **0.43** | 0.41 | **0.37** | 0.31 | 81.1 | 63.3 |
| BR00112204 | **0.4** | 0.35 | **0.46** | 0.29 | 82.2 | 58.9 |
| BR00112197binned | **0.43** | 0.41 | **0.39** | 0.3 | 86.7 | 58.9 |
| BR00112201 | **0.47** | 0.4 | **0.41** | 0.32 | 91.1 | 66.7 |