theislab / chemCPA

Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
https://arxiv.org/abs/2204.13545
MIT License

Check differentially expressed genes #52

Closed (siboehm closed this issue 2 years ago)

siboehm commented 2 years ago

When training on LINCS (on our new random_split), we run a full evaluation after the epochs have finished. This runs into an error:

2021-12-28 14:25:33 (INFO): Running the full evaluation (Epoch:300)
Number of different r2 computations: 120
Number of different r2 computations: 34158
2021-12-28 14:33:35 (ERROR): Failed after 4:54:16!
Traceback (most recent call last):
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/experiment.py", line 312, in run_commandline
    return self.run(
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/experiment.py", line 276, in run
    run()
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/run.py", line 238, in __call__
    self.result = self.main_function(*args)
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/seml_sweep_icb.py", line 321, in train
    return experiment.train()
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/seml_sweep_icb.py", line 256, in train
    evaluation_stats = evaluate(
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/train.py", line 322, in evaluate
    "ood": evaluate_r2(
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/train.py", line 241, in evaluate_r2
    np.array(dataset.de_genes[cell_drug_dose_comb])
KeyError: 'A375_DMSO_0.1'

I briefly looked into it: the key A375_DMSO_0.1 doesn't exist in dataset.de_genes, and neither does A375_control_0.1 or any other variation. I don't quite know what the underlying cause is, but we should fix it.

It's not a major issue, but it causes every seml run to fail without recording any results in MongoDB.

MxMstrmn commented 2 years ago

TL;DR:

It does not make sense to evaluate on de_genes for DMSO observations. I suggest checking whether the observation is a control in the evaluate_r2 method. Alternatively, we could exclude all DMSO observations from the ood set in the random split.
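A minimal sketch of the first option, skipping control categories before the DE-gene lookup that raised the KeyError. The names is_control and de_scores_skipping_controls are hypothetical, and the sketch assumes category keys follow the celltype_drugname_drugdose pattern, so a control key contains "DMSO" as one of its "_"-separated fields.

```python
import numpy as np


def is_control(pert_category: str, control_name: str = "DMSO") -> bool:
    """True if a key such as 'A375_DMSO_0.1' names a control (assumes '_'-separated fields)."""
    return control_name in pert_category.split("_")


def de_scores_skipping_controls(de_genes, categories, score_fn, control_name="DMSO"):
    """Hypothetical stand-in for the DE-gene part of evaluate_r2.

    score_fn(category, de_genes_array) computes whatever metric is wanted for one category.
    """
    results = {}
    for cat in categories:
        if is_control(cat, control_name):
            # DE genes are defined relative to DMSO, so controls have no entry: skip them
            continue
        de = np.array(de_genes[cat])  # same lookup that failed in the traceback
        results[cat] = score_fn(cat, de)
    return results
```

With this check, a category like A375_DMSO_0.1 never reaches the dict lookup, so the evaluation reports scores for drug categories only.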


Detailed explanation:

We run the evaluation over two sets of genes. First, we simply look at the reconstruction of all genes. Second, we check the reconstruction of only those genes that are especially relevant to the perturbation (that is, the drug) of the observation at hand. See this code and the subsequent lines: https://github.com/theislab/chemical_CPA/blob/4c280faaa26dc396184647f3733c365fae6127ff/compert/train.py#L221
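To illustrate the two gene sets, here is a minimal sketch assuming R^2 is computed between the mean true and mean predicted expression of one perturbation category. The function names are hypothetical and the actual evaluate_r2 in train.py may differ in detail.

```python
import numpy as np


def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot for two 1-D vectors."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def r2_all_and_de(true_cells: np.ndarray, pred_cells: np.ndarray, de_idx):
    """R^2 over all genes and over the DE genes only.

    true_cells, pred_cells: arrays of shape (n_cells, n_genes) for one category.
    de_idx: integer column indices of that category's DE genes.
    """
    true_mean = true_cells.mean(axis=0)
    pred_mean = pred_cells.mean(axis=0)
    r2_all = r2_score_1d(true_mean, pred_mean)
    r2_de = r2_score_1d(true_mean[de_idx], pred_mean[de_idx])
    return r2_all, r2_de
```

A prediction can score well on the DE subset while scoring poorly on all genes (or vice versa), which is why both numbers are reported.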

This second set of genes, indicated by the _de suffix, consists of the so-called de_genes, short for differentially expressed genes. These genes can only be computed between two conditions, in our case drug and control, where control refers to the observations treated with DMSO. It is therefore not surprising that we do not store any set of genes for DMSO, which explains the error.

The de_genes attribute is itself a dict that maps each condition to its DE genes. It is populated from adata.uns here: https://github.com/theislab/chemical_CPA/blob/4c280faaa26dc396184647f3733c365fae6127ff/compert/data.py#L112
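A sketch of how such a dict can be turned into column indices for the DE subset, assuming it stores gene names (as the notebook below does) and that var_names lists the genes in column order. de_gene_indices is a hypothetical helper, not the repository's code.

```python
import numpy as np


def de_gene_indices(de_genes: dict, key: str, var_names) -> np.ndarray:
    """Integer column positions (ascending) of the DE genes stored under `key`."""
    de_names = np.asarray(de_genes[key])  # raises KeyError for keys with no entry, e.g. DMSO controls
    return np.flatnonzero(np.isin(np.asarray(var_names), de_names))
```

For a control key such as A375_DMSO_0.1, the dict lookup fails before any indices are computed, matching the traceback above.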

The adata.uns attribute is populated in the preprocessing/lincs.ipynb notebook from cell 36 onwards. Code block from the notebook:

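%%time
from tqdm.notebook import tqdm
import numpy as np
import scanpy as sc

# `adata` (the LINCS AnnData object) and the boolean flag `full` come from earlier notebook cells
de_genes = {}
de_genes_quick = {}

adata_df = adata.to_df()
adata_df['condition'] = adata.obs.condition
# numeric_only=True keeps the non-numeric 'condition' column out of the per-gene means
dmso = adata_df[adata_df.condition == "DMSO"].mean(numeric_only=True)

# Quick ranking: the 50 genes with the largest |mean(drug) - mean(DMSO)| for each condition
for cond, df in tqdm(adata_df.groupby('condition')):
    if cond != 'DMSO':  # DE genes are defined relative to DMSO, so DMSO gets no entry
        drug_mean = df.mean(numeric_only=True)
        de_50_idx = np.argsort(np.abs((drug_mean - dmso).to_numpy()))[-50:]
        de_genes_quick[cond] = drug_mean.index[de_50_idx].values

if full:
    # When `full` is truthy, the quick mean-difference ranking above is used as-is
    de_genes = de_genes_quick

else:
    # Otherwise rank genes with scanpy, comparing each condition against DMSO
    sc.tl.rank_genes_groups(
        adata,
        groupby='condition',
        reference='DMSO',
        rankby_abs=True,
        n_genes=50
    )
    for cond in tqdm(np.unique(adata.obs['condition'])):
        if cond != 'DMSO':
            df = sc.get.rank_genes_groups_df(adata, group=cond)  # this takes a while
            de_genes[cond] = df['names'][:50].values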
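The quick ranking in the cell above can be sketched as a self-contained function on a toy expression table. quick_de_genes is a hypothetical name; it selects the same top-n set as the argsort version (ties aside) but orders genes largest difference first.

```python
import numpy as np
import pandas as pd


def quick_de_genes(expr: pd.DataFrame, condition: pd.Series,
                   control: str = "DMSO", n_genes: int = 50) -> dict:
    """Top-n genes per condition by |mean(condition) - mean(control)|, skipping the control."""
    means = expr.groupby(condition.to_numpy()).mean()  # one row of per-gene means per condition
    control_mean = means.loc[control]
    de = {}
    for cond, cond_mean in means.iterrows():
        if cond == control:
            continue  # no DE genes can be defined for the control against itself
        diff = (cond_mean - control_mean).abs()
        de[cond] = diff.nlargest(n_genes).index.to_numpy()
    return de
```

The returned dict has no "DMSO" key by construction, which is exactly why a lookup for a DMSO category fails later.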

I checked commit 1e21ddff3a7b180342f85a7a6229a2b132fd3030 and the updated lincs_SMILES.ipynb notebook. Judging by the notebook's cell counters (cells 33 and 34), it seems that we checked the cell_drug_dose_comb values, which are the pert_categories defined in the experiment YAML file (see the same commit). YAML config:

dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose
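A sketch of how keys of the celltype_drugname_drugdose form arise, assuming hypothetical obs column names cell_type, condition and dose. It shows that control rows produce keys such as A375_DMSO_0.1 that have no DE entry.

```python
import pandas as pd


def make_cov_drug_dose_name(obs: pd.DataFrame, cell_type_key: str = "cell_type",
                            drug_key: str = "condition", dose_key: str = "dose") -> pd.Series:
    """Build one celltype_drugname_drugdose key per observation (column names are assumptions)."""
    return (
        obs[cell_type_key].astype(str)
        + "_" + obs[drug_key].astype(str)
        + "_" + obs[dose_key].astype(str)
    )
```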

notebook:

for i, k in enumerate(adata.obs.cov_drug_dose_name.unique()):
    # Report every non-DMSO category without a DE-gene entry (DMSO keys are expected to be missing)
    if k not in adata.uns['rank_genes_groups_cov'] and 'DMSO' not in k:
        print(f"{i}: {k}")
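The same check can be written as a reusable helper that separates expected gaps (controls) from unexpected ones (drug categories that should have DE genes). audit_de_keys is hypothetical and assumes "_"-separated keys as above.

```python
def audit_de_keys(categories, de_genes, control="DMSO"):
    """Split categories missing from de_genes into (expected control gaps, unexpected gaps)."""
    expected, unexpected = [], []
    for k in dict.fromkeys(categories):  # de-duplicate while preserving order
        if k in de_genes:
            continue
        if control in k.split("_"):
            expected.append(k)
        else:
            unexpected.append(k)
    return expected, unexpected
```

A non-empty second list would point to a genuine preprocessing problem, while the first list only confirms that controls have no DE genes.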

https://github.com/theislab/chemical_CPA/blob/4c280faaa26dc396184647f3733c365fae6127ff/compert/train.py#L227-L229

siboehm commented 2 years ago

That was a great explanation! Thanks for writing it up. So it seems that for the original splits, Mo / Carlo explicitly excluded DMSO perturbations from the ood split, which is why they didn't run into the error, correct?
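For the alternative fix (keeping DMSO out of the ood split), a minimal pandas sketch. The column names split and condition, the label "ood", and moving controls to "train" rather than dropping them are all assumptions, not how the original splits were built.

```python
import pandas as pd


def move_controls_out_of_ood(obs: pd.DataFrame, split_key: str = "split",
                             condition_key: str = "condition", control: str = "DMSO",
                             ood_label: str = "ood", new_label: str = "train") -> pd.DataFrame:
    """Return a copy of obs in which control cells are never labelled ood."""
    obs = obs.copy()
    obs[split_key] = obs[split_key].astype("object")  # avoid errors if the column is categorical
    mask = (obs[condition_key] == control) & (obs[split_key] == ood_label)
    obs.loc[mask, split_key] = new_label
    return obs
```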

I'll write some code later to just skip any DMSO perturbations and rerun the sweep.