carpenter-singh-lab / 2024_vanDijk_PLoS_CytoSummaryNet


General model experiments #3

Open EchteRobert opened 2 years ago

EchteRobert commented 2 years ago

This issue tracks experiments on general aspects of model development that are not directly related to, but are likely still influenced by, the dataset or model hyperparameters in use at the time.

EchteRobert commented 2 years ago

Experiment

During training I sample cells from each well to create tensors of uniform size; this is required for creating batches larger than 1. It is generally known in contrastive learning/metric learning that a large batch size is essential for good model performance, which I also showed in the first issue: https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/1. To test whether this sampling affects the model's performance, I evaluate the model with a range of numbers of sampled cells and also without sampling, i.e. collapsing all of a well's cell features into a single profile, as is done with the mean.
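A minimal sketch of the two evaluation modes, assuming each well is a NumPy array of shape `(n_cells, n_features)`. The names `sample_cells`, `make_batch`, and `aggregate_all_cells` are hypothetical, the mean is only a stand-in for the learned aggregation model, and oversampling with replacement for wells with too few cells is an assumption rather than the documented behavior.

```python
import numpy as np


def sample_cells(well, n_cells, rng):
    """Draw a fixed number of cells (rows) from one well's (n_cells, n_features) array."""
    # Assumption: wells with fewer cells than requested are sampled with replacement.
    replace = well.shape[0] < n_cells
    idx = rng.choice(well.shape[0], size=n_cells, replace=replace)
    return well[idx]


def make_batch(wells, n_cells, seed=0):
    """Stack sampled wells into one (n_wells, n_cells, n_features) tensor.

    Wells with different cell counts cannot be stacked directly, which is why
    sampling is needed for batch sizes larger than 1.
    """
    rng = np.random.default_rng(seed)
    return np.stack([sample_cells(w, n_cells, rng) for w in wells])


def aggregate_all_cells(well):
    """No sampling: every cell of the well contributes to the profile.

    The mean is a placeholder for the learned aggregation model; with a batch
    size of 1 a model can consume all cells of a well at once.
    """
    return well.mean(axis=0)


rng = np.random.default_rng(1)
wells = [rng.normal(size=(n, 16)) for n in (120, 45, 300)]  # hypothetical wells
batch = make_batch(wells, n_cells=100)
profiles = np.stack([aggregate_all_cells(w) for w in wells])
print(batch.shape, profiles.shape)  # (3, 100, 16) (3, 16)
```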

Hypothesis

Model performance will improve slightly when all cells are used without sampling, because sampling may pick out cells that are less representative of the perturbation.

Main takeaways

From here on evaluation will be done without sampling, simply by collapsing all cells into a feature representation.

Exciting stuff here!

_no sampling_
![MLP_BS64_noSampling_PR](https://user-images.githubusercontent.com/62173977/154756633-5da6068d-b143-4ca4-a204-aaf1377083b2.png)

_600 cells sampled_
![MLP_BS64_600cells_PR](https://user-images.githubusercontent.com/62173977/154756663-5bed0c7d-5f11-4379-9e66-544d066ac970.png)

_400 cells sampled_
![MLP_BS64_400cells_PR](https://user-images.githubusercontent.com/62173977/154756693-2a59f5b0-3c2d-4126-8762-53a61a3cdb04.png)

_200 cells sampled_
![MLP_BS64_200cells_PR](https://user-images.githubusercontent.com/62173977/154756719-0ce736b3-e650-471b-b873-83a1eca2dc86.png)

_100 cells sampled_
![MLP_BS64_100cells_PR](https://user-images.githubusercontent.com/62173977/154756761-33b81db5-9dd8-4673-a11b-01ec4e00bb46.png)

_50 cells sampled_
![MLP_BS64_50cells_PR](https://user-images.githubusercontent.com/62173977/154756790-01a607cc-31a9-4206-8130-1269e4a190a0.png)

_cell distribution histogram_
![NrCellsHist](https://user-images.githubusercontent.com/62173977/154757284-53215f5c-185e-4180-b815-b337fad77148.png)
EchteRobert commented 2 years ago

Model capacity experiment 1

Goal: Test whether the current training method (model architecture, optimizer, and loss function) can learn to distinguish point sets from each other when their means are the same.

Background: This experiment is the first in a series of experiments that aim to figure out what the feature aggregation model is learning. We expect that it learns to select both cells and features when aggregating single-cell feature data; however, this is hard to verify. So instead, this experiment tests the model's capacity to learn the covariance matrix of the input data.

The general setup is as follows:

To create the data:

Training process:

If the model can complete this task, this verifies that the architecture is able to learn covariance matrices, and suggests that the feature aggregation model I am proposing can also do this (or is already doing it) for the single-cell feature data.
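As a hypothetical sketch (not necessarily the procedure used here), zero-mean point sets that differ only in their covariance can be generated from rotated 2-D ellipses. The angle and axis-length ranges below are invented; the counts (10 classes, 4 samples of 800 points each) follow the numbers given later in this thread.

```python
import numpy as np


def ellipse_covariance(angle, major, minor):
    """2x2 covariance of an ellipse with the given orientation and axis standard deviations."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s], [s, c]])
    return rotation @ np.diag([major**2, minor**2]) @ rotation.T


def sample_point_set(cov, n_points, rng):
    """Zero-mean Gaussian point set, so only the covariance differs between classes."""
    return rng.multivariate_normal(mean=np.zeros(2), cov=cov, size=n_points)


def make_toy_dataset(n_classes=10, samples_per_class=4, n_points=800, seed=0):
    rng = np.random.default_rng(seed)
    # Hypothetical parameter ranges for the class-defining ellipses.
    angles = rng.uniform(0.0, np.pi, n_classes)
    majors = rng.uniform(1.0, 3.0, n_classes)
    minors = rng.uniform(0.2, 1.0, n_classes)
    point_sets, labels = [], []
    for k in range(n_classes):
        cov = ellipse_covariance(angles[k], majors[k], minors[k])
        for _ in range(samples_per_class):
            point_sets.append(sample_point_set(cov, n_points, rng))
            labels.append(k)
    return np.stack(point_sets), np.array(labels)


points, labels = make_toy_dataset()
print(points.shape, labels.shape)  # (40, 800, 2) (40,)
```

Because every class has mean zero, a mean-aggregation baseline has no signal to work with, while a model that captures second-order structure can separate the classes.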

Main takeaways

Results

Cool ellipses!

![CovarianceClasses](https://user-images.githubusercontent.com/62173977/163246750-bb9f1569-14c9-47f1-ab4b-ce4b91936f27.png)
Model performance

Total model mAP: 0.9589285714285715
Total model precision at R: 0.9166666666666666

| sample | class | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 7 | 4 | 1 | 1 |
| 8 | 5 | 1 | 1 |
| 9 | 5 | 1 | 1 |
| 10 | 5 | 1 | 1 |
| 11 | 5 | 1 | 1 |
| 12 | 7 | 1 | 1 |
| 13 | 7 | 1 | 1 |
| 14 | 7 | 1 | 1 |
| 15 | 7 | 1 | 1 |
| 4 | 4 | 0.866667 | 0.666667 |
| 5 | 4 | 0.833333 | 0.666667 |
| 6 | 4 | 0.642857 | 0.333333 |
Baseline (mean) performance

Total baseline (mean) mAP: 0.3086709586709586
Total baseline (mean) precision at R: 0.20833333333333331

| sample | class | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 11 | 5 | 0.471429 | 0.333333 |
| 8 | 5 | 0.465368 | 0.333333 |
| 12 | 7 | 0.460317 | 0.333333 |
| 2 | 0 | 0.447619 | 0.333333 |
| 3 | 0 | 0.447619 | 0.333333 |
| 14 | 7 | 0.316667 | 0.333333 |
| 13 | 7 | 0.305556 | 0.333333 |
| 0 | 0 | 0.280952 | 0.333333 |
| 7 | 4 | 0.256614 | 0.333333 |
| 4 | 4 | 0.25 | 0 |
| 9 | 5 | 0.240909 | 0 |
| 15 | 7 | 0.233333 | 0.333333 |
| 1 | 0 | 0.206044 | 0 |
| 5 | 4 | 0.205026 | 0 |
| 10 | 5 | 0.189377 | 0 |
| 6 | 4 | 0.161905 | 0 |
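The tables report, per sample, the average precision (AP) and the precision at R. Below is a minimal sketch of both metrics, assuming each sample is used as a query against all other samples, ranked by cosine similarity (the similarity measure is an assumption), with R equal to the number of other samples of the same class. With 4 samples per class, R = 3, which is consistent with the precision-at-R values above being multiples of 1/3; an AP of 0.866667, for example, corresponds to hits at ranks 1, 2, and 5: (1 + 1 + 3/5) / 3.

```python
import numpy as np


def average_precision(relevance):
    """AP of one query, given 0/1 relevance of candidates in ranked order."""
    rel = np.asarray(relevance, dtype=float)
    if rel.sum() == 0:
        return 0.0
    precision_at_k = np.cumsum(rel) / np.arange(1, len(rel) + 1)
    return float((precision_at_k * rel).sum() / rel.sum())


def precision_at_r(relevance):
    """Fraction of positives among the top R candidates, R = number of positives."""
    rel = np.asarray(relevance, dtype=float)
    r = int(rel.sum())
    return float(rel[:r].sum() / r) if r > 0 else 0.0


def per_sample_scores(profiles, labels):
    """(AP, precision at R) for every sample used as a query against all others."""
    unit = profiles / np.linalg.norm(profiles, axis=1, keepdims=True)
    similarity = unit @ unit.T  # cosine similarity (assumed metric)
    labels = np.asarray(labels)
    scores = []
    for i in range(len(labels)):
        others = np.delete(np.arange(len(labels)), i)
        ranked = others[np.argsort(-similarity[i, others], kind="stable")]
        relevance = labels[ranked] == labels[i]
        scores.append((average_precision(relevance), precision_at_r(relevance)))
    return scores


print(round(average_precision([1, 1, 0, 0, 1]), 6))  # 0.866667
```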
EchteRobert commented 2 years ago

Model capacity experiment 2

Continuing the previous idea of learning higher-order statistics (i.e. the covariance matrix and beyond) with the current model setup, this experiment repeats the previous one on real data. More specifically, it uses the Stain3 data as described here https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/6. I train on plates BR00115134_FS, BR00115125_FS, and BR00115133highexp_FS and validate on the rest.

Note: This experiment is not ideal as I am using the data that is normalized to zero mean and unit variance per feature on the plate level. I then subsequently zero mean the data on the well level as well, meaning that taking the average profile is useless. I used this data because it is faster than preprocessing all of the data again.

I am not too sure what this means for the standard deviation (or covariance matrix) on the well level. What do you think, @shntnu?

Results

Table Time!

| plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
|:------------------|---------------------:|------------------:|-----------------------:|--------------------:|-----------:|--------:|
| _Training plates_ | | | | | | |
| **BR00115134** | 0.44 | 0.03 | 0.24 | 0.04 | 90 | 12.2 |
| **BR00115125** | 0.37 | 0.03 | 0.25 | 0.04 | 94.4 | 6.7 |
| **BR00115133highexp** | 0.38 | 0.02 | 0.17 | 0.02 | 92.2 | 2.2 |
| _Validation plates_ | | | | | | |
| **BR00115128highexp** | 0.33 | 0.03 | 0.25 | 0.04 | 81.1 | 11.1 |
| **BR00115125highexp** | 0.29 | 0.03 | 0.22 | 0.02 | 78.9 | 5.6 |
| **BR00115131** | 0.32 | 0.03 | 0.22 | 0.03 | 85.6 | 11.1 |
| **BR00115133** | 0.32 | 0.03 | 0.16 | 0.04 | 83.3 | 5.6 |
| **BR00115127** | 0.29 | 0.03 | 0.22 | 0.05 | 84.4 | 3.3 |
| **BR00115128** | 0.33 | 0.03 | 0.29 | 0.04 | 86.7 | 2.2 |
| **BR00115129** | 0.3 | 0.03 | 0.26 | 0.04 | 82.2 | 5.6 |
| BR00115126 | 0.2 | 0.03 | 0.22 | 0.04 | 50 | 7.8 |
shntnu commented 2 years ago

> I am not too sure what this means for the standard deviation (or covariance matrix) on the well level.

My notes are below

> Note: This experiment is not ideal as I am using the data that is normalized to zero mean and unit variance per feature on the plate level.

^^^ This is fine

> I then subsequently zero mean the data on the well level as well, meaning that taking the average profile is useless.

Perfect, because you've not scaled to unit variance (otherwise you'd be looking for structure in the correlation matrix, instead of the covariance matrix)
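To make the distinction concrete, here is a minimal NumPy sketch on a hypothetical well (made-up covariance, not Stain3 data): well-level mean-centering makes the mean profile zero while leaving the covariance matrix unchanged, whereas additionally scaling each feature to unit variance turns the covariance matrix into the correlation matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical well: 500 cells x 3 features with unequal scales and correlations.
true_cov = np.array([[4.0, 1.2, 0.0],
                     [1.2, 1.0, 0.3],
                     [0.0, 0.3, 0.25]])
cells = rng.multivariate_normal(mean=[1.0, -2.0, 0.5], cov=true_cov, size=500)

# Well-level mean-centering only (what is done in this experiment).
centered = cells - cells.mean(axis=0)
# Additionally scaling to unit variance (what is NOT done in this experiment).
standardized = centered / centered.std(axis=0, ddof=1)

print(np.round(centered.mean(axis=0), 12))  # all zeros: the mean profile carries no signal
print(np.allclose(np.cov(centered, rowvar=False), np.cov(cells, rowvar=False)))  # True
print(np.allclose(np.cov(standardized, rowvar=False), np.corrcoef(cells, rowvar=False)))  # True
```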

> I used this data because it is faster than preprocessing all of the data again.

That worked out well!

I've not peeked into the results but I am eagerly looking forward to your talk tomorrow where you might discuss more 🎉

shntnu commented 2 years ago

Ok, I couldn't contain my excitement so I looked at the results :D

I just picked one at random and focused only on:

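> | plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
> |:------------------|---------------------:|------------------:|-----------------------:|--------------------:|-----------:|--------:|
> | _Validation plates_ | | | | | | |
> | BR00115128highexp | 0.33 | 0.03 | 0.25 | 0.04 | 81.1 | 11.1 |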

This is fantastic, right?!

EchteRobert commented 2 years ago

Yes, I think so too! I think this is proof that we can learn more than the mean. The mAP of the benchmark looks random (I believe it should be ~1/30, but the math around mAP is not as intuitive to me as that of precision at K :) ). Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?
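The expected mAP of a random ranking can be checked numerically. The sketch below is a Monte Carlo estimate under hypothetical parameters (the number of candidates and positives per query in the Stain3 evaluation is not stated here). For a single positive among N candidates, the positive lands at each rank k with probability 1/N and then has AP = 1/k, so the exact expectation is H_N / N (the N-th harmonic number divided by N), which is larger than 1/N.

```python
import numpy as np


def average_precision(relevance):
    rel = np.asarray(relevance, dtype=float)
    precision_at_k = np.cumsum(rel) / np.arange(1, len(rel) + 1)
    return float((precision_at_k * rel).sum() / rel.sum())


def random_ap(n_candidates, n_positives, n_trials=20000, seed=0):
    """Monte Carlo estimate of the expected AP when candidates are ranked at random."""
    rng = np.random.default_rng(seed)
    relevance = np.zeros(n_candidates)
    relevance[:n_positives] = 1.0
    return float(np.mean([average_precision(rng.permutation(relevance))
                          for _ in range(n_trials)]))


def expected_ap_single_positive(n_candidates):
    """Exact expectation for one positive among N candidates: H_N / N."""
    return float(np.sum(1.0 / np.arange(1, n_candidates + 1)) / n_candidates)


# Hypothetical query: 1 positive among 30 candidates.
print(round(expected_ap_single_positive(30), 3))  # 0.133
print(round(random_ap(30, 1), 3))                 # close to the exact value
```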

shntnu commented 2 years ago

> Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?

I didn't understand – you're already learning the covariance because you only mean subtract wells, right?

EchteRobert commented 2 years ago

Yes you're right. I mean seeing if we can learn third order interactions. Probably easiest if we discuss it tomorrow

shntnu commented 2 years ago

> Yes you're right. I mean seeing if we can learn third order interactions. Probably easiest if we discuss it tomorrow

Ah by fixing you mean factoring out – got it

For that, you'd spherize instead of mean subtracting

That will be totally shocking if it works!!

Even second order is pretty awesome (IMO, unless there's something trivial happening here https://broadinstitute.slack.com/archives/C025JFCBQAK/p1650466733918839?thread_ts=1649774854.681729&cid=C025JFCBQAK)
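A minimal sketch of why sphering factors out second-order structure, using ZCA whitening as one common choice (whether ZCA or another whitening transform is used in this project is an assumption). After the transform, the sample covariance is the identity, so any remaining differences between point sets must come from higher-order statistics.

```python
import numpy as np


def zca_sphere(x):
    """ZCA-sphere an (n_samples, n_features) array: zero mean, identity covariance."""
    centered = x - x.mean(axis=0)
    eigvals, eigvecs = np.linalg.eigh(np.cov(centered, rowvar=False))
    whitening = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T
    return centered @ whitening


rng = np.random.default_rng(0)
cov = np.array([[4.0, 1.5], [1.5, 1.0]])  # hypothetical elongated, tilted ellipse
points = rng.multivariate_normal([0.0, 0.0], cov, size=800)
sphered = zca_sphere(points)
print(np.round(np.cov(sphered, rowvar=False), 6))  # approximately the identity matrix
```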

shntnu commented 2 years ago

PS – unless something super trivial is happening here that we haven't caught, I think you might be pretty close to having something you can write up. Let's get together with @johnarevalo for his advice, maybe next week

EchteRobert commented 2 years ago

Worth giving it a shot ;) Sounds good!

shntnu commented 2 years ago

For the toy data, also standardize after rotating and see what happens then. The idea is that we don't yet know whether it is learning the covariance or just the standard deviation.
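Why this is an informative control can be sketched analytically. Assuming 2-D ellipses parameterized by a rotation angle and two axis lengths (hypothetical values below), standardizing each feature turns the covariance matrix into the correlation matrix. Axis-aligned ellipses then become indistinguishable (correlation 0), while tilted ellipses keep a non-zero correlation whose sign depends on the tilt. If each class is standardized separately, the per-feature standard deviations are all 1, so a model that still separates the classes cannot be relying on them.

```python
import numpy as np


def ellipse_covariance(angle, major, minor):
    """2x2 covariance of a rotated ellipse with the given axis standard deviations."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s], [s, c]])
    return rotation @ np.diag([major**2, minor**2]) @ rotation.T


def covariance_to_correlation(cov):
    """Effect of standardizing each feature to unit variance on a covariance matrix."""
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)


for degrees in (0, 90, 30, 150):
    corr = covariance_to_correlation(ellipse_covariance(np.deg2rad(degrees), 2.0, 0.5))
    print(f"{degrees:>3} deg: correlation = {corr[0, 1]:+.3f}")
```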

EchteRobert commented 2 years ago

Repeat of same experiment with standardized feature dimensions

Based on @shntnu's previous comment.

Main takeaways

Total model mAP: 0.93 Total model precision at R: 0.92

Total baseline (mean) mAP: 0.33 Total baseline (mean) precision at R: 0.21

The model is still beating the baseline (random) performance.

Ellipsoid classes after standardizing points

![CovarianceClassesStandardized](https://user-images.githubusercontent.com/62173977/165144692-252bf148-96e2-44f7-ab91-45ff2f4ffcd6.png)
mean Average Precision scores

Total model mAP: 0.9282512626262627
Total model precision at R: 0.9166666666666666

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 4 | 4 | 1 | 1 |
| 5 | 4 | 1 | 1 |
| 6 | 4 | 1 | 1 |
| 7 | 4 | 1 | 1 |
| 12 | 7 | 1 | 1 |
| 13 | 7 | 1 | 1 |
| 15 | 7 | 1 | 1 |
| 9 | 5 | 0.916667 | 1 |
| 10 | 5 | 0.866667 | 0.666667 |
| 11 | 5 | 0.866667 | 1 |
| 14 | 7 | 0.757576 | 0.666667 |
| 8 | 5 | 0.444444 | 0.333333 |

Total baseline (mean) mAP: 0.32838989713989714
Total baseline (mean) precision at R: 0.20833333333333331

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 8 | 5 | 0.638889 | 0.666667 |
| 14 | 7 | 0.54359 | 0.333333 |
| 1 | 0 | 0.474074 | 0.333333 |
| 12 | 7 | 0.4 | 0.333333 |
| 13 | 7 | 0.383333 | 0.333333 |
| 9 | 5 | 0.361111 | 0.333333 |
| 11 | 5 | 0.354701 | 0.333333 |
| 4 | 4 | 0.347222 | 0.333333 |
| 6 | 4 | 0.289683 | 0.333333 |
| 10 | 5 | 0.277778 | 0 |
| 0 | 0 | 0.244444 | 0 |
| 7 | 4 | 0.214286 | 0 |
| 3 | 0 | 0.197619 | 0 |
| 2 | 0 | 0.188889 | 0 |
| 5 | 4 | 0.185606 | 0 |
| 15 | 7 | 0.153014 | 0 |
shntnu commented 2 years ago

> The model is still beating the baseline (random) performance.

Great!

And I verified, as a sanity check, that the baseline hasn't changed (much) from before https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/3#issuecomment-1098357632

The model results don't change much either (correlation vs covariance; details below)

BTW, you show 10 ellipses but you have 16 rows in your results. Why is that?

**Correlation**
From the most recent results https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/3#issuecomment-1108862653
Total model mAP: 0.9282512626262627
Total model precision at R: 0.9166666666666666

**Covariance**
From the results 12 days ago https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/3#issuecomment-1098357632
Total model mAP: 0.9589285714285715
Total model precision at R: 0.9166666666666666
EchteRobert commented 2 years ago

Great, thanks for checking!

There are 10 classes, with 4 samples (of 800 points each) per class. The validation set consists of 4 classes, so 16 samples in total. I report all samples individually here; normally I take the mean per class (compound).

Interesting to note (perhaps for my future self): to learn the correlation adequately, I had to reduce the learning rate by a factor of 100 (lr: 1e-5) compared to learning the covariance (lr: 1e-3).

EchteRobert commented 2 years ago

Experiment 3 - Sphering the toy data

In this experiment, for each covariance matrix class I sample 800*4 points from the corresponding distribution, sphere this sample, and then subsample it to create pairs for training. I increase the number of epochs from 40 to 1000 because the model is not able to fit the data otherwise.
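A hypothetical sketch of this pipeline, assuming the regularization value is added to the eigenvalues of the covariance matrix before taking inverse square roots, as in regularized ZCA whitening (the exact sphering implementation used here is not shown in this thread). The subset size used to form training pairs is also a made-up parameter. Under this assumption, smaller regularization values whiten more strongly, which matches 0.01 being labeled heavy sphering and 0.3 low sphering below.

```python
import numpy as np


def regularized_sphere(x, eps):
    """Regularized ZCA: covariance eigenvalues lam become lam / (lam + eps)."""
    centered = x - x.mean(axis=0)
    eigvals, eigvecs = np.linalg.eigh(np.cov(centered, rowvar=False))
    whitening = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return centered @ whitening


def make_pair(points, subset_size, rng):
    """Two disjoint random subsets of one sphered class sample, forming a positive pair."""
    idx = rng.permutation(len(points))
    return points[idx[:subset_size]], points[idx[subset_size:2 * subset_size]]


rng = np.random.default_rng(0)
cov = np.array([[4.0, 1.5], [1.5, 1.0]])  # hypothetical class covariance
class_sample = rng.multivariate_normal([0.0, 0.0], cov, size=800 * 4)
for eps in (0.01, 0.1, 0.3):  # heavy, medium, low sphering
    sphered = regularized_sphere(class_sample, eps)
    print(eps, np.round(np.linalg.eigvalsh(np.cov(sphered, rowvar=False)), 3))
first, second = make_pair(sphered, subset_size=400, rng=rng)  # hypothetical subset size
```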

Main takeaways

Regularization 0.01 - heavy sphering

![Spherize0_01](https://user-images.githubusercontent.com/62173977/165328993-046c3c4f-8fe0-4e34-94eb-943870fe162e.png)

Total model mAP: 0.2943837412587413
Total model precision at R: 0.125
Total baseline (mean) mAP: 0.25055043336293337
Total baseline (mean) precision at R: 0.125
full tables

**Model**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 10 | 5 | 0.535354 | 0.333333 |
| 11 | 5 | 0.535354 | 0.333333 |
| 13 | 7 | 0.516667 | 0.333333 |
| 2 | 0 | 0.293651 | 0.333333 |
| 3 | 0 | 0.289683 | 0 |
| 4 | 4 | 0.288889 | 0.333333 |
| 8 | 5 | 0.268687 | 0 |
| 12 | 7 | 0.266667 | 0 |
| 9 | 5 | 0.25 | 0 |
| 0 | 0 | 0.238095 | 0 |
| 1 | 0 | 0.233333 | 0 |
| 5 | 4 | 0.22906 | 0.333333 |
| 14 | 7 | 0.227273 | 0 |
| 15 | 7 | 0.227273 | 0 |
| 6 | 4 | 0.165568 | 0 |
| 7 | 4 | 0.144589 | 0 |

**Baseline**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.451282 | 0.333333 |
| 5 | 4 | 0.273016 | 0.333333 |
| 11 | 5 | 0.255495 | 0 |
| 8 | 5 | 0.254701 | 0.333333 |
| 13 | 7 | 0.251852 | 0.333333 |
| 4 | 4 | 0.240741 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 0 | 0 | 0.225397 | 0.333333 |
| 1 | 0 | 0.215812 | 0 |
| 14 | 7 | 0.212169 | 0 |
| 15 | 7 | 0.199074 | 0 |
| 12 | 7 | 0.177778 | 0 |
| 10 | 5 | 0.169841 | 0 |
| 7 | 4 | 0.155556 | 0 |
Regularization 0.1 - medium sphering

![Spherize0_1](https://user-images.githubusercontent.com/62173977/165329032-38d8a000-0063-4d68-83af-8d8b0c7d18ac.png)

Total model mAP: 0.6088789682539683
Total model precision at R: 0.5
Total baseline (mean) mAP: 0.2500837125837126
Total baseline (mean) precision at R: 0.125
full tables

**Model**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 0 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 12 | 7 | 0.791667 | 0.666667 |
| 9 | 5 | 0.722222 | 0.666667 |
| 1 | 0 | 0.638889 | 0.666667 |
| 11 | 5 | 0.638889 | 0.666667 |
| 5 | 4 | 0.591667 | 0.333333 |
| 8 | 5 | 0.588889 | 0.666667 |
| 15 | 7 | 0.569444 | 0.333333 |
| 10 | 5 | 0.555556 | 0.666667 |
| 14 | 7 | 0.533333 | 0.333333 |
| 6 | 4 | 0.341667 | 0 |
| 4 | 4 | 0.319444 | 0 |
| 13 | 7 | 0.302778 | 0 |
| 7 | 4 | 0.147619 | 0 |

**Baseline**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.455556 | 0.333333 |
| 5 | 4 | 0.273016 | 0.333333 |
| 8 | 5 | 0.254701 | 0.333333 |
| 13 | 7 | 0.251852 | 0.333333 |
| 11 | 5 | 0.24359 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 14 | 7 | 0.228836 | 0 |
| 0 | 0 | 0.225397 | 0.333333 |
| 1 | 0 | 0.220862 | 0 |
| 4 | 4 | 0.220539 | 0 |
| 15 | 7 | 0.205026 | 0 |
| 12 | 7 | 0.177778 | 0 |
| 10 | 5 | 0.165568 | 0 |
| 7 | 4 | 0.152525 | 0 |
Regularization 0.3 - low sphering

![Spherize0_03](https://user-images.githubusercontent.com/62173977/165329338-18d528fe-c7fa-4e3d-95bb-86e9881f62a5.png)

Total model mAP: 0.722172619047619
Total model precision at R: 0.625
Total baseline (mean) mAP: 0.25495106745106744
Total baseline (mean) precision at R: 0.10416666666666666
full tables

**Model**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 0 | 0 | 1 | 1 |
| 1 | 0 | 1 | 1 |
| 2 | 0 | 1 | 1 |
| 3 | 0 | 1 | 1 |
| 11 | 5 | 0.916667 | 0.666667 |
| 13 | 7 | 0.916667 | 0.666667 |
| 14 | 7 | 0.916667 | 0.666667 |
| 12 | 7 | 0.755556 | 0.666667 |
| 5 | 4 | 0.7 | 0.333333 |
| 9 | 5 | 0.626984 | 0.666667 |
| 8 | 5 | 0.622222 | 0.666667 |
| 6 | 4 | 0.588889 | 0.666667 |
| 7 | 4 | 0.5 | 0.333333 |
| 4 | 4 | 0.47619 | 0.333333 |
| 15 | 7 | 0.31746 | 0.333333 |
| 10 | 5 | 0.21746 | 0 |

**Baseline**

| sample | compound | AP | precision at R |
|---:|-----------:|---------:|-----------------:|
| 2 | 0 | 0.455556 | 0.333333 |
| 3 | 0 | 0.455556 | 0.333333 |
| 5 | 4 | 0.333333 | 0.333333 |
| 13 | 7 | 0.288889 | 0.333333 |
| 11 | 5 | 0.24359 | 0 |
| 6 | 4 | 0.240741 | 0 |
| 1 | 0 | 0.233333 | 0 |
| 9 | 5 | 0.229798 | 0 |
| 0 | 0 | 0.22906 | 0.333333 |
| 8 | 5 | 0.226923 | 0 |
| 4 | 4 | 0.212602 | 0 |
| 14 | 7 | 0.210606 | 0 |
| 15 | 7 | 0.205026 | 0 |
| 12 | 7 | 0.184615 | 0 |
| 10 | 5 | 0.17033 | 0 |
| 7 | 4 | 0.159259 | 0 |
shntnu commented 2 years ago
> After heavy sphering of the data the model is no longer able to learn how to discern the different classes.

Ah, this is expected (and thus, good!) because your data is almost surely fully explained by its second-order moments (that's how you generated it), and sphering factors that out.

The story is different with your real data: there, it will almost surely not be fully explained by second-order moments (although that doesn't mean you can do better).

> After medium or low sphering the model is still able to beat the baseline, although it requires many more training steps with a smaller learning rate.

Perfect! As expected, and it's great that you quantified it in terms of how much more complicated it is.

Note that medium / low sphering should be roughly equivalent to a medium / low value for major/minor axis.
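A short worked calculation, assuming regularized ZCA sphering in which each covariance eigenvalue λ is mapped to λ / (λ + ε) (the exact sphering code is not shown in this thread). An ellipse with axis variances λ_major and λ_minor is left with a major/minor axis ratio of sqrt(λ_major (λ_minor + ε) / (λ_minor (λ_major + ε))). This ratio is 1 (a circle) at ε = 0 and approaches the original ratio sqrt(λ_major / λ_minor) as ε grows, so weaker sphering leaves more of each class's elongation intact.

```python
import math


def axis_ratio_after_sphering(lam_major, lam_minor, eps):
    """Major/minor axis ratio after regularized ZCA sphering with regularization eps."""
    return math.sqrt(lam_major * (lam_minor + eps) / (lam_minor * (lam_major + eps)))


# Hypothetical ellipse with axis standard deviations 2.0 and 0.5 (ratio 4 before sphering).
for eps in (0.0, 0.01, 0.1, 0.3, 1e6):
    print(eps, round(axis_ratio_after_sphering(4.0, 0.25, eps), 3))
```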