EchteRobert opened 2 years ago
During training I sample cells from each well to create uniform tensors; this is a requirement for creating batches larger than 1. It is generally known in contrastive/metric learning that a large batch size is essential for good model performance, something I also showed in the first issue: https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/1. To test whether the model's performance is affected by this sampling, I evaluate the model using a range of sampled-cell counts and also without sampling, i.e. collapsing all existing cell features into a profile, as with the mean.
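The sampling step described above can be sketched as follows (the function name, cell counts, and feature dimension are illustrative, not taken from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_well(cells: np.ndarray, n_cells: int) -> np.ndarray:
    """Sample a fixed number of cells from one well so that every
    well yields a tensor of shape (n_cells, n_features)."""
    # Sample with replacement only when the well has fewer cells than requested.
    replace = cells.shape[0] < n_cells
    idx = rng.choice(cells.shape[0], size=n_cells, replace=replace)
    return cells[idx]

# Two wells with different cell counts collapse to the same shape,
# so they can be stacked into a single batch tensor.
well_a = rng.normal(size=(1200, 4))   # 1200 cells, 4 features
well_b = rng.normal(size=(300, 4))    # 300 cells, 4 features
batch = np.stack([sample_well(well_a, 400), sample_well(well_b, 400)])
print(batch.shape)  # (2, 400, 4)
```

Without this step, wells with different cell counts cannot be stacked into one batch, which is why a batch size above 1 requires it.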
Model performance is slightly improved by using all existing cells without extra sampling, because sampling may pick cells that are less representative of the perturbation.
From here on evaluation will be done without sampling, simply by collapsing all cells into a feature representation.
Goal: Test if the current training method (model architecture, optimizer, and loss function) is capable of learning to distinguish point sets from each other while their means are the same.
Background: This experiment is the first one in a series of experiments that aims to figure out what the feature aggregation model is learning. We expect that it is learning to both select cells and features when aggregating single-cell level feature data, however this is hard to verify. So instead, this experiment aims to test the model’s capacity to learn the covariance matrix of the input data.
The general setup is as follows:
To create the data:
Training process:
```python
model = MLPsumV2(input_dim=2, latent_dim=64, output_dim=2, k=4, dropout=0,
                 cell_layers=1, proj_layers=2, reduction='sum')
```
`nr_train_samples//4`
(in this experiment, bs=6). If it is able to complete this task, we can verify that this architecture is able to learn covariance matrices, and that the feature aggregation model I am proposing is most likely also able to do (or is doing) this for the single-cell level feature data.
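A minimal sketch of the kind of toy data this experiment calls for: point sets that share an identical (zero) mean but differ in covariance through rotation. The axis lengths and angles here are illustrative, not the exact values used in the experiment.

```python
import numpy as np

rng = np.random.default_rng(0)

def rotated_class(angle: float, n_points: int = 800) -> np.ndarray:
    """A 2D point set with exactly zero mean whose covariance is an
    axis-aligned ellipse rotated by `angle`."""
    base = rng.normal(size=(n_points, 2)) * np.array([3.0, 0.5])  # major/minor axes
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle),  np.cos(angle)]])
    pts = base @ rot.T
    return pts - pts.mean(axis=0)  # force identical (zero) means

a = rotated_class(0.0)
b = rotated_class(np.pi / 4)
print(np.allclose(a.mean(0), b.mean(0)))                # True: identical means
print(np.allclose(np.cov(a.T), np.cov(b.T), atol=0.5))  # False: covariances differ
```

Because the means match by construction, only second-order structure can separate the classes, which is exactly what the experiment probes.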
Continuing on the previous idea of learning higher order (i.e. covariance matrix and higher) statistics with the current model setup, this experiment aims to learn to repeat the previous one but on real data. More specifically, the Stain3 data is used as described here https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/6. I train using plates BR00115134_FS, BR00115125_FS, and BR00115133highexp_FS and validate on the rest.
Note: This experiment is not ideal as I am using the data that is normalized to zero mean and unit variance per feature on the plate level. I then subsequently zero mean the data on the well level as well, meaning that taking the average profile is useless. I used this data because it is faster than preprocessing all of the data again.
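A quick illustration of why the per-well mean centering makes the average profile uninformative while leaving the covariance intact (toy data, not the Stain3 features):

```python
import numpy as np

rng = np.random.default_rng(0)
well = rng.normal(loc=2.0, size=(500, 10))          # 500 cells, 10 features
centered = well - well.mean(axis=0, keepdims=True)  # zero-mean per well

# The mean profile of every centered well is the zero vector,
# so mean aggregation can no longer distinguish wells...
print(np.allclose(centered.mean(axis=0), 0))  # True
# ...but mean subtraction leaves the covariance matrix unchanged.
print(np.allclose(np.cov(well.T), np.cov(centered.T)))  # True
```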
I am not too sure what this means for the standard deviation (or covariance matrix) on the well level. What do you think @shntnu ?
> I am not too sure what this means for the standard deviation (or covariance matrix) on the well level.
My notes are below
> Note: This experiment is not ideal as I am using the data that is normalized to zero mean and unit variance per feature on the plate level.
^^^ This is fine
> I then subsequently zero mean the data on the well level as well, meaning that taking the average profile is useless.
Perfect, because you've not scaled to unit variance (otherwise you'd be looking for structure in the correlation matrix, instead of the covariance matrix)
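The distinction above can be made concrete with a small NumPy check (toy data): the covariance matrix of data that has additionally been scaled to unit variance is exactly the correlation matrix of the original data, so scaling would change which structure is left to learn.

```python
import numpy as np

rng = np.random.default_rng(0)
# Correlated toy features (not the real plate data).
x = rng.multivariate_normal([0, 0], [[4.0, 1.5], [1.5, 1.0]], size=2000)

centered = x - x.mean(0)                    # mean subtraction only
standardized = centered / centered.std(0)   # additionally scale to unit variance

# Covariance of the standardized data IS the correlation matrix of the
# original data, term by term.
print(np.allclose(np.cov(standardized.T, ddof=0), np.corrcoef(x.T)))  # True
```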
> I used this data because it is faster than preprocessing all of the data again.
That worked out well!
I've not peeked into the results but I am eagerly looking forward to your talk tomorrow where you might discuss more 🎉
Ok, I couldn't contain my excitement so I looked at the results :D
I just picked one at random, and focused only on
plate | Training mAP model | Training mAP BM | Validation mAP model | Validation mAP BM | PR model | PR BM |
---|---|---|---|---|---|---|
*Validation plates* | | | | | | |
BR00115128highexp | 0.33 | 0.03 | 0.25 | 0.04 | 81.1 | 11.1 |
This is fantastic, right?!
Yes I think so too! I think this is proof that we can learn more than the mean. The mAP of the benchmark looks random (I believe it should be ~1/30, but the math around mAP is not as intuitive to me as the precision at K :) ). Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?
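The chance-level mAP can be checked empirically with a small simulation (the candidate and replicate counts below are hypothetical, chosen only to illustrate the order of magnitude):

```python
import numpy as np

rng = np.random.default_rng(0)

def average_precision(labels: np.ndarray) -> float:
    """AP of a ranked list of binary relevance labels."""
    hits = np.cumsum(labels)
    ranks = np.arange(1, len(labels) + 1)
    return float(np.sum(labels * hits / ranks) / labels.sum())

# Hypothetical retrieval setup: 4 replicate wells among 120 candidates.
n, r = 120, 4
labels = np.concatenate([np.ones(r), np.zeros(n - r)])

aps = []
for _ in range(5000):
    rng.shuffle(labels)
    aps.append(average_precision(labels))

# Chance-level mAP is of the same order as the positive fraction r/n,
# though somewhat above it (early lucky hits inflate AP).
print(round(float(np.mean(aps)), 3), round(r / n, 3))
```

This is consistent with the intuition that a random benchmark should score near the replicate fraction, while precision at K is exactly r/n in expectation for a random ranking.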
> Perhaps we can now try fixing the covariance matrix as well and see if we can still learn?
I didn't understand: you're already learning the covariance because you only mean-subtract wells, right?
Yes you're right. I mean seeing if we can learn third order interactions. Probably easiest if we discuss it tomorrow
> Yes you're right. I mean seeing if we can learn third order interactions. Probably easiest if we discuss it tomorrow
Ah, by fixing you mean factoring out. Got it.
For that, you'd spherize instead of mean subtracting
That will be totally shocking if it works!!
Even second order is pretty awesome (IMO, unless there's something trivial happening here https://broadinstitute.slack.com/archives/C025JFCBQAK/p1650466733918839?thread_ts=1649774854.681729&cid=C025JFCBQAK)
PS – unless something super trivial is happening here that we haven't caught, I think you might be pretty close to having something you can write up. Let's get together with @johnarevalo for his advice, maybe next week
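Spherizing (whitening) as suggested above can be sketched with ZCA whitening; this is one common choice of sphering transform, not necessarily the one used in the repo:

```python
import numpy as np

rng = np.random.default_rng(0)
# Correlated toy data standing in for the well-level features.
x = rng.multivariate_normal([0, 0], [[4.0, 1.5], [1.5, 1.0]], size=2000)

def sphere(data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """ZCA-whiten: rotate and scale so the covariance becomes identity."""
    centered = data - data.mean(0)
    vals, vecs = np.linalg.eigh(np.cov(centered.T))
    w = vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T
    return centered @ w

z = sphere(x)
print(np.allclose(np.cov(z.T), np.eye(2), atol=1e-2))  # True
```

After sphering, all second-order structure is removed, so any remaining class signal must come from third-order moments or higher. A larger `eps` (or only partially rescaling the eigenvalues) gives weaker, "partial" sphering.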
Worth giving it a shot ;) Sounds good!
For the toy data, also standardize after rotating and see what happens then. The idea is that we don't yet know whether it is learning the covariance or just the standard deviation.
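The standardize-after-rotating check can be sketched as follows (axis lengths and angle are illustrative): per-feature standard deviations are forced to 1, yet the rotation still leaves a nonzero correlation, so any remaining class signal must come from the correlation rather than the standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)

def rotated_sample(angle: float, n: int = 800) -> np.ndarray:
    """Axis-aligned ellipse of points rotated by `angle`."""
    base = rng.normal(size=(n, 2)) * np.array([3.0, 0.5])
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle),  np.cos(angle)]])
    return base @ rot.T

x = rotated_sample(np.pi / 6)
z = (x - x.mean(0)) / x.std(0)   # standardize after rotating

# Per-feature std is gone (all ones), but the rotation-induced
# off-diagonal correlation survives standardization.
print(np.allclose(z.std(0), 1.0))           # True
print(abs(np.corrcoef(z.T)[0, 1]) > 0.5)    # True
```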
Based on @shntnu's previous comment.
 | mAP | Precision at R
---|---|---
Total model | 0.93 | 0.92
Total baseline (mean) | 0.33 | 0.21
The model is still beating the baseline (random) performance.
> The model is still beating the baseline (random) performance.
Great!
And I verified, as sanity check, that the baseline hasn't changed (much) from before https://github.com/broadinstitute/FeatureAggregation_single_cell/issues/3#issuecomment-1098357632
The model results don't change much either (correlation vs covariance; details below)
BTW, you show 10 ellipses but you have 16 rows in your results. Why is that?
Great, thanks for checking!
There are 10 classes, but 4 samples (of 800 points each) of each class. The validation set consists of 4 classes, so 16 samples total. I report all samples individually here; normally I take the mean per class (compound).
Interesting to note (perhaps for myself in the future): I had to reduce the learning rate by a factor of 100 (lr: 1e-5) to learn the correlation adequately with the model, compared to learning the covariance (lr: 1e-3).
In this experiment I sample 800*4 points from each covariance-matrix class's distribution, then sphere this sample and subsequently subsample it to create pairs for training. I increase the number of epochs from 40 to 1000, as the model is otherwise not able to fit the data.
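The sample-sphere-subsample pipeline above can be sketched like this (the covariance values are illustrative; the sphering here is ZCA whitening, one common choice):

```python
import numpy as np

rng = np.random.default_rng(0)

def make_class_sample(cov, n: int = 800 * 4) -> np.ndarray:
    """Draw one class's full point set from its covariance matrix."""
    return rng.multivariate_normal(np.zeros(2), cov, size=n)

def sphere(data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """ZCA-whiten so the sample covariance becomes identity."""
    centered = data - data.mean(0)
    vals, vecs = np.linalg.eigh(np.cov(centered.T))
    return centered @ vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T

full = make_class_sample([[4.0, 1.5], [1.5, 1.0]])
sphered = sphere(full)

# Subsample the sphered set into two disjoint "views" to form a training pair.
idx = rng.permutation(len(sphered))
pair = sphered[idx[:800]], sphered[idx[800:1600]]
print(pair[0].shape, pair[1].shape)  # (800, 2) (800, 2)
```

Because sphering is applied to the full class sample before subsampling, the pairs share no second-order structure that distinguishes the class, which is what makes this task so much harder to fit.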
- After heavy sphering of the data the model is no longer able to learn how to discern the different classes.
Ah, this is expected (and thus, good!) because your data is almost surely fully explained by its second-order moments (that's how you generated it), and sphering factors that out.
The story is different with your real data: there, it will almost surely not be fully explained by second-order moments (although that doesn't mean you can do better).
After medium or low sphering the model is still able to beat the baseline, although it requires many more training steps with a smaller learning rate.
Perfect! as expected, and it's great that you quantified it in terms of how much more complicated it is
Note that medium/low sphering should be roughly equivalent to medium/low values for the major/minor axes.
This issue is used to test more general aspects of model development not directly related to, but likely still influenced by, the dataset or model hyperparameters that are used at that time.