snap-stanford / GEARS

GEARS is a geometric deep learning model that predicts outcomes of novel multi-gene perturbations
MIT License

Issue with new_data_process function #45

Closed: carversh closed this issue 8 months ago

carversh commented 8 months ago

Hi,

I am trying to input my own data into GEARS and am encountering an error, even though I believe I formatted my scanpy object correctly. I also tried setting the Ensembl IDs as the index of the .var DataFrame, but that triggered the same error. Any suggestions? FYI, I deleted the .raw attribute because I couldn't save my h5ad file without deleting it. The formatting in .raw differs from the formatting you require in the .obs and .var DataFrames.

Here is the line of code that is triggering an error:

pert_data.new_data_process(dataset_name = 'my_data', adata = adata_final) # specific dataset name and adata object

Here is the error being triggered:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[140], line 2
      1 pert_data = PertData('./data') # specific saved folder
----> 2 pert_data.new_data_process(dataset_name = 'my_data', adata = adata_final) # specific dataset name and adata object
      3 # pert_data.load(data_path = './data/my_data') # load the processed data, the path is saved folder + dataset_name
      4 # pert_data.prepare_split(split = 'simulation', seed = 1) # get data split with seed
      5 # pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

File ~/opt/anaconda3/envs/GraphPerturb/lib/python3.10/site-packages/gears/pertdata.py:250, in PertData.new_data_process(self, dataset_name, adata, skip_calc_de)
    248     os.mkdir(save_data_folder)
    249 self.dataset_path = save_data_folder
--> 250 self.adata = get_DE_genes(adata, skip_calc_de)
    251 if not skip_calc_de:
    252     self.adata = get_dropout_non_zero_genes(self.adata)

File ~/opt/anaconda3/envs/GraphPerturb/lib/python3.10/site-packages/gears/data_utils.py:64, in get_DE_genes(adata, skip_calc_de)
     62 adata.obs = adata.obs.astype('category')
     63 if not skip_calc_de:
---> 64     rank_genes_groups_by_cov(adata, 
     65                      groupby='condition_name', 
     66                      covariate='cell_type', 
     67                      control_group='ctrl_1', 
     68                      n_genes=len(adata.var),
     69                      key_added = 'rank_genes_groups_cov_all')
...
    109     )
    111 adata_comp = adata
    112 if layer is not None:

ValueError: Could not calculate statistics for groups ct1_DLG5-AS1+PPM1D_1+1, ct1_IRF2+TIRAP_1+1,...

Here is the structure of my count matrix:

AnnData object with n_obs × n_vars = 15744 × 5000
    obs: 'condition', 'cell_type'
    var: 'gene_name'
    uns: 'hvg'

Here is the structure of adata_final.var:

     gene_name
0   AL669831.5
1    LINC00115
2       FAM41C
10        HES4
11       ISG15

Here is the structure of adata_final.obs:

                         condition   cell_type
AAACCCAAGTGGCAGT-1_1  SBDS+TNFSF15  ct1
AAACGAACAATGTCTG-1_1     NIPBL-AS1  ct1
AAACGAATCCTCTCTT-1_1          ctrl  ct1
AAACGCTCAAGTTCCA-1_1   ctrl+RAD54B  ct1
AAACGCTCAGGTCAAG-1_1      HLA-DRB6  ct1

carversh commented 8 months ago

I realize it's because many perturbation conditions have only a single cell (sample) in my data.
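One way to confirm this before calling new_data_process is to count cells per (cell type, condition) group: the traceback shows GEARS ranking genes per condition within each cell_type, and a differential-expression test has nothing to compare for a group containing a single cell. The sketch below is a hypothetical pre-check (find_small_groups, keep_mask and min_cells are illustrative names, not GEARS functions), written in plain Python so it accepts any sequence of labels, such as adata_final.obs['cell_type'] and adata_final.obs['condition'].

```python
from collections import Counter
from collections.abc import Iterable


def find_small_groups(
    cell_types: Iterable[str],
    conditions: Iterable[str],
    min_cells: int = 2,
) -> dict[tuple[str, str], int]:
    """Return {(cell_type, condition): n_cells} for every group with fewer than min_cells cells.

    Hypothetical pre-check, not GEARS code: it only mirrors the per-cell-type,
    per-condition grouping visible in the traceback.
    """
    counts = Counter(zip(cell_types, conditions))
    return {group: n for group, n in counts.items() if n < min_cells}


def keep_mask(
    cell_types: Iterable[str],
    conditions: Iterable[str],
    min_cells: int = 2,
) -> list[bool]:
    """Boolean mask that is True for cells whose group has at least min_cells cells."""
    cell_types, conditions = list(cell_types), list(conditions)  # allow iterating twice
    small = find_small_groups(cell_types, conditions, min_cells)
    return [(ct, cond) not in small for ct, cond in zip(cell_types, conditions)]


# Toy labels in the style of adata_final.obs (the counts are made up).
cell_types = ["ct1"] * 6
conditions = ["ctrl", "ctrl", "ctrl+RAD54B", "ctrl+RAD54B", "SBDS+TNFSF15", "NIPBL-AS1"]
print(find_small_groups(cell_types, conditions))
# {('ct1', 'SBDS+TNFSF15'): 1, ('ct1', 'NIPBL-AS1'): 1}
print(keep_mask(cell_types, conditions))
# [True, True, True, True, False, False]
```

On the real object, the mask could be applied as adata_final[np.asarray(mask)].copy() with mask = keep_mask(adata_final.obs['cell_type'], adata_final.obs['condition']). The traceback also shows that get_DE_genes only calls rank_genes_groups_by_cov when skip_calc_de is false, so passing skip_calc_de=True to new_data_process should sidestep this particular error, presumably at the cost of the differential-expression results that step would otherwise compute.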

carversh commented 8 months ago

Actually, in order to predict on a new dataset, do I need multiple samples of the same perturbation in my input data? I wouldn't think so if I'm just running predictions with an already trained model.

yhr91 commented 8 months ago

Yes, having only a single sample of a specific perturbation type does throw unexpected errors. I believe this is linked to the section of the dataloader that computes differentially expressed genes.

If you are using these single-sample perturbations for training or validation, that isn't really recommended anyway. If you only want to use them at inference time, you don't need to include their post-perturbation expression in the dataloader when training the model; you can directly predict the effect of that perturbation.
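This suggestion can be sketched as: keep multi-cell conditions for training and validation, and set single-cell ones aside to be predicted directly by the trained model. It is a minimal sketch under stated assumptions: condition labels are "+"-joined gene names with "ctrl" as the control token (as in the .obs excerpt above), cells are counted per condition regardless of cell type, and split_conditions and to_gene_lists are hypothetical helpers, not GEARS functions. The last comment assumes the trained model's predict method takes a list of gene-name lists, as in the GEARS README examples; check that against your installed version.

```python
from collections import Counter
from collections.abc import Iterable


def split_conditions(
    conditions: Iterable[str], min_cells: int = 2
) -> tuple[list[str], list[str]]:
    """Split condition labels into (trainable, inference_only) by cell count.

    Hypothetical helper: conditions with at least min_cells cells stay in the
    training data; the rest are held out for prediction only. "ctrl" is always
    kept as trainable because it is the reference group.
    """
    counts = Counter(conditions)
    trainable = sorted(c for c, n in counts.items() if n >= min_cells or c == "ctrl")
    inference_only = sorted(c for c, n in counts.items() if n < min_cells and c != "ctrl")
    return trainable, inference_only


def to_gene_lists(conditions: Iterable[str]) -> list[list[str]]:
    """Turn labels like "SBDS+TNFSF15" or "ctrl+RAD54B" into gene lists, dropping "ctrl"."""
    gene_lists = []
    for cond in conditions:
        genes = [g for g in cond.split("+") if g != "ctrl"]
        if genes:  # skip the pure control label
            gene_lists.append(genes)
    return gene_lists


conditions = ["ctrl", "ctrl", "ctrl+RAD54B", "ctrl+RAD54B", "SBDS+TNFSF15", "NIPBL-AS1"]
train, infer = split_conditions(conditions)
print(train)                 # ['ctrl', 'ctrl+RAD54B']
print(infer)                 # ['NIPBL-AS1', 'SBDS+TNFSF15']
print(to_gene_lists(infer))  # [['NIPBL-AS1'], ['SBDS', 'TNFSF15']]
# gears_model.predict(to_gene_lists(infer))  # assumed API; gears_model = a trained GEARS model
```

Counting per condition rather than per (cell type, condition) keeps the sketch short; with several cell types, the per-group counts from the pre-check above would be the stricter criterion.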

> in order to predict on a new dataset

At the moment, training and prediction are done in the context of the same dataset. We have not designed GEARS for cross-dataset prediction.