theislab / chemCPA

Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
https://arxiv.org/abs/2204.13545
MIT License

Adapt Adversarial predictor for large number of classes #22

Closed siboehm closed 2 years ago

siboehm commented 2 years ago

Currently the CPA Adv predictor outputs a probability distribution over all possible drugs, given the basal state. This works well for datasets like Trapnell (188 drugs) but for LINCS (>17K drugs) it is not feasible because there aren't enough samples of each drug.

Therefore this needs to be adapted:

  1. Cluster all LINCS drugs, for example based on their GROVER embeddings, into a small number (<1000) of classes. Save the cluster assignments to a DataFrame, which will be loaded at runtime.
  2. The predictor now predicts the cluster assignment, not the drug. Weight the classes in the loss function to compensate for cluster size imbalances. (Also weight the vanilla CPA adversarial loss to compensate for imbalances during finetuning.)
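
A minimal sketch of the class weighting in step 2, assuming simple inverse-frequency weights (the issue does not fix a weighting scheme). The names `cluster_of_drug`, `sample_drugs` and `cluster_class_weights` are hypothetical and the data is a toy example, not LINCS:

```python
from collections import Counter


def cluster_class_weights(cluster_of_drug, sample_drugs):
    """Inverse-frequency weight per cluster.

    Weights are normalised so that a perfectly balanced dataset
    would give every cluster a weight of 1.0.
    """
    # Count how many samples fall into each cluster.
    counts = Counter(cluster_of_drug[drug] for drug in sample_drugs)
    n_samples = sum(counts.values())
    n_clusters = len(counts)
    return {c: n_samples / (n_clusters * n) for c, n in counts.items()}


# Hypothetical toy data: three drugs assigned to two clusters.
cluster_of_drug = {"drugA": 0, "drugB": 0, "drugC": 1}
sample_drugs = ["drugA"] * 6 + ["drugB"] * 2 + ["drugC"] * 2

weights = cluster_class_weights(cluster_of_drug, sample_drugs)
print(weights)
```

With PyTorch, weights like these could be passed as the `weight` argument of `torch.nn.CrossEntropyLoss`, ordered by class index.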

Different approaches are possible during finetuning:

  1. Train a new predictor, using the standard BCE loss over all drugs. Advantage: definitely removes all information about the drug.
  2. Keep the old predictor, using the BCE loss over cluster assignments. Advantage: no need to train a new predictor.

The final CPA model is used to predict counterfactuals. For this to work, all information about the drug must have been removed from the latent basal state. This may not fully be the case when we predict just the cluster assignment. Example: if each cluster contains both potent and less potent drugs, a notion of potency may remain in the latent basal state even though the cluster can no longer be predicted. If we use strategy 1) during finetuning, we'll definitely ensure "latent basal state drug ambivalence" for the final model. An alternative approach may be to have the Adv predictor predict the drug embedding directly (using a smaller, <1000-dim drug embedding) and use the cosine distance to the "true" drug embedding as the adversarial loss. This is loosely similar to how BERT models predict words, by scoring the output vector against the word embeddings. For now the cluster strategy seems more promising.
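
A rough sketch of the cosine-distance idea, in plain Python for illustration only. The function name `cosine_distance` and the `eps` guard are assumptions, not code from the repo; a real version would operate on batched tensors:

```python
import math


def cosine_distance(pred, target, eps=1e-8):
    """1 - cosine similarity between a predicted and a true drug embedding."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm_pred = math.sqrt(sum(p * p for p in pred))
    norm_target = math.sqrt(sum(t * t for t in target))
    # eps guards against division by zero for all-zero vectors.
    return 1.0 - dot / max(norm_pred * norm_target, eps)


# Same direction gives 0, orthogonal gives 1, opposite gives 2.
same = cosine_distance([1.0, 0.0], [2.0, 0.0])
orthogonal = cosine_distance([1.0, 0.0], [0.0, 3.0])
opposite = cosine_distance([1.0, 0.0], [-1.0, 0.0])
```

The adversarial predictor would be trained to minimise this distance, while the encoder would be trained to keep it large.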

siboehm commented 2 years ago

Leon does the clustering, Simon writes the predictor

MxMstrmn commented 2 years ago

A bit unclear to me: how many different clusters are sensible given that we have 17k different drugs?

siboehm commented 2 years ago

The number of clusters is probably less important than the number of samples per cluster. For reference, these are the sample counts for the ten least frequent drugs in Trapnell:

> adata.obs["condition"].value_counts()[-10:]
Rigosertib      984
Luminespib      980
Tozasertib      975
Mocetinostat    949
Alvespimycin    930
AT9283          910
Patupilone      757
Flavopiridol    693
Epothilone      583
YM155           394
Name: condition, dtype: int64

So it might be nice to have >500 samples per cluster in LINCS. I was somewhat surprised to see that LINCS isn't even that big: 840,677 samples in LINCS vs 290,888 samples in Trapnell.
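
As a back-of-the-envelope check of what the >500 target implies, here is a small sketch. The helper `small_clusters` and the toy labels are hypothetical; only the LINCS sample count comes from the numbers above:

```python
from collections import Counter

LINCS_SAMPLES = 840_677  # sample count quoted above
MIN_SAMPLES = 500        # target from the comment above

# Upper bound on the number of clusters if every cluster
# is to hold at least MIN_SAMPLES samples.
max_clusters = LINCS_SAMPLES // MIN_SAMPLES


def small_clusters(sample_clusters, min_samples=MIN_SAMPLES):
    """Return {cluster: count} for clusters below the sample threshold."""
    counts = Counter(sample_clusters)
    return {c: n for c, n in counts.items() if n < min_samples}


# Hypothetical per-sample cluster labels: cluster 0 is fine, cluster 1 is too small.
too_small = small_clusters([0] * 600 + [1] * 20)
print(max_clusters, too_small)
```

So under this target, even a perfectly balanced clustering could not use more than about 1,700 clusters.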

MxMstrmn commented 2 years ago

@siboehm I do not remember, did we merge my cluster assignments already?

siboehm commented 2 years ago

No. You did generate the clusters afaik, but we never merged any code. So far all I've done is port the loss from binary cross-entropy (BCE) to standard cross-entropy, as BCE requires creating a one-hot encoded (OHE) target vector, which would take way too long with this many classes.
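
To illustrate why the OHE vector is not needed for the softmax cross-entropy, here is a small sketch in plain Python (not the repo's code; both function names are made up for this example). Indexing the log-softmax with an integer class label gives the same value as the dot product with the one-hot target:

```python
import math


def log_sum_exp(logits):
    """Numerically stable log(sum(exp(logits)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def cross_entropy_from_index(logits, target_index):
    """Softmax cross-entropy using only the integer class label."""
    return log_sum_exp(logits) - logits[target_index]


def cross_entropy_from_one_hot(logits, one_hot):
    """Same loss, but computed against an explicit one-hot target vector."""
    lse = log_sum_exp(logits)
    return -sum(t * (x - lse) for x, t in zip(logits, one_hot))


logits = [2.0, 0.5, -1.0]
loss_index = cross_entropy_from_index(logits, 1)
loss_one_hot = cross_entropy_from_one_hot(logits, [0.0, 1.0, 0.0])
```

The index version touches one logit per sample for the target term, whereas the one-hot version needs a vector as long as the number of classes for every sample.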