snap-stanford / GEARS

GEARS is a geometric deep learning model that predicts outcomes of novel multi-gene perturbations
MIT License

no_test option for data split yields error #29

Closed: ekernf01 closed this issue 9 months ago

ekernf01 commented 11 months ago

Hi Yusuf et al., congratulations on publishing GEARS -- the paper is excellent and impressively thorough.

Can GEARS use all of the input data for training or validation, leaving nothing for a test set? Running as recommended seems to create a train-val-test split with 25% of the data in the test fold. I am using GEARS in a setting where the test data have already been set aside, and forgoing another 25% of the remaining data could make a big difference to performance. I notice there is a 'no_test' option implemented at line 174 of data_utils, but I get an error when I use it (full example below). I am using gears version 0.0.4.
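For context, the split being asked for can be sketched in plain Python. Perturbation conditions (not individual cells) are shuffled with a fixed seed and divided between train and val, and no fold is reserved for testing. The helper `split_conditions_no_test`, the 10% validation fraction, and the choice to keep `ctrl` in train are all hypothetical. This is not GEARS's own implementation.

```python
import random


def split_conditions_no_test(conditions, val_frac=0.1, seed=5):
    """Hypothetical helper: assign perturbation conditions to train/val only.

    Assumption: 'ctrl' (if present) always stays in the training fold.
    This is an illustration, not the logic used inside GEARS.
    """
    unique = set(conditions)
    perts = sorted(c for c in unique if c != "ctrl")  # sort so shuffling is reproducible
    rng = random.Random(seed)  # local RNG: the same seed gives the same split
    rng.shuffle(perts)
    # Reserve at least one condition for validation so 'val' is never empty.
    n_val = max(1, round(len(perts) * val_frac))
    train = (["ctrl"] if "ctrl" in unique else []) + perts[n_val:]
    return {"train": train, "val": perts[:n_val]}
```

Because the split is made at the condition level, every cell of a given perturbation lands in the same fold. That is what makes a validation score meaningful for unseen perturbations.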

Example code:

from gears import PertData, GEARS
dataset_name = 'adamson'
pert_data = PertData('./data', default_pert_graph=False)
pert_data.load(data_name = dataset_name)
pert_data.prepare_split(split = 'no_test', seed = 5) 
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 

The error:

  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/pertdata.py", line 308, in get_dataloader
    for p in self.set2conditions[i]:
KeyError: 'val'
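The traceback shows `get_dataloader` looping over `self.set2conditions[i]`. This suggests the split never created a `'val'` entry, so a direct lookup of that fold name fails. The sketch below reproduces the symptom with a hypothetical plain dictionary rather than GEARS's actual object. Skipping absent folds avoids the `KeyError`, but that is only appropriate when a fold is genuinely optional. It illustrates the failure mode and is not the fix that was applied in the repository.

```python
def conditions_per_fold(set2conditions, folds=("train", "val", "test")):
    """Hypothetical sketch: collect conditions only for folds present in the split."""
    result = {}
    for fold in folds:
        # set2conditions[fold] would raise KeyError for a missing fold,
        # so folds that the split did not create are skipped instead.
        if fold not in set2conditions:
            continue
        result[fold] = list(set2conditions[fold])
    return result


# A split dictionary with no 'val' entry, mimicking the traceback.
# Indexing split_without_val["val"] directly would raise KeyError: 'val'.
split_without_val = {"train": ["ctrl", "A+ctrl"]}
folds = conditions_per_fold(split_without_val)  # {'train': ['ctrl', 'A+ctrl']}
```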
weizhiting commented 11 months ago

I have the same issue.

yhr91 commented 11 months ago

Thanks very much for your feedback @ekernf01! Thanks also for pointing out issues in the code.

I've fixed the no_test split issue in the repo. I've also added a custom split option, which a few others had also asked for; it takes a split dictionary as input.
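A split dictionary for the custom option might be prepared as shown below. This sketch assumes the file is a pickled dict that maps fold names (`train`, `val`, `test`) to lists of condition strings. Check the GEARS source for the exact format your version expects. The helper `write_split_dict` and the file name `custom_split.pkl` are hypothetical. The condition names come from the Dixit example later in this thread.

```python
import pickle
from pathlib import Path


def write_split_dict(split, path):
    """Hypothetical helper: pickle a fold -> conditions mapping.

    The assumed format (dict of fold name -> list of conditions) is not
    confirmed here; verify it against the GEARS version you have installed.
    """
    # Fail early if a condition is assigned to more than one fold.
    seen = set()
    for fold, conds in split.items():
        overlap = seen & set(conds)
        if overlap:
            raise ValueError(f"conditions in more than one fold: {sorted(overlap)}")
        seen |= set(conds)
    path = Path(path)
    with path.open("wb") as f:
        pickle.dump({fold: list(conds) for fold, conds in split.items()}, f)
    return path


split = {
    "train": ["ctrl", "ELK1+ctrl", "ELF1+ctrl", "CREB1+ctrl"],
    "val": ["EGR1+ctrl"],
    "test": ["YY1+ctrl"],
}
```

The resulting file would then be passed through the new `split_dict_path` argument, for example `pert_data.prepare_split(split='custom', split_dict_path='custom_split.pkl')`, assuming that keyword usage matches your installed version.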

I will update the pip package soon.

ekernf01 commented 11 months ago

Thank you! I installed the new version, and I can see the new "split_dict_path" argument and the "custom" option, but I still get the same KeyError when I run the code above.

yhr91 commented 11 months ago

Yes, that's because I hadn't updated the pip package yet. It has now been updated to v0.1.0 and should hopefully work fine. It also now includes preprocessed dataloaders for two additional datasets.

ekernf01 commented 10 months ago

I have run into an edge case that looks almost the same as this bug but only happens on certain inputs. In the example below, the no_test split works on the full Dixit dataset but fails on a subset containing about 2/3 of the cells and 1/2 of the conditions.

from gears import PertData
dataset_name = 'dixit'
trainset_perts = [
    "ctrl"         ,
    "ELK1+ctrl"    ,
    "ELF1+ctrl"    ,
    "CREB1+ctrl"   ,
    "EGR1+ctrl"    ,
    "YY1+ctrl"     ,
    "NR2C2+ctrl"   ,
    "GABPA+ctrl"   ,
    "RACGAP1+ctrl" ,
    "TOR1AIP1+ctrl",
]
pert_data = PertData('./data', default_pert_graph=False)
pert_data.load(data_name = dataset_name)
dixit = pert_data.adata
dixit = dixit[dixit.obs["condition"].isin(trainset_perts), :]
pert_data = PertData('./data', default_pert_graph=False)
pert_data.new_data_process(dataset_name = 'current', adata = dixit)   
pert_data.prepare_split(split = 'no_test', seed = 5) 
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 

Error:

Traceback (most recent call last):
  File "/home/ekernf01/Downloads/splitter_bug.py", line 23, in <module>
    pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/pertdata.py", line 457, in get_dataloader
    for p in self.set2conditions[i]:
KeyError: 'val'
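On affected versions, a quick sanity check on the split can turn this `KeyError` into a clearer message before any dataloaders are built. The helper `check_split` is hypothetical and only inspects a plain dictionary. It assumes you pass it the fold-to-conditions mapping that the tracebacks show GEARS storing as `set2conditions`, e.g. `check_split(pert_data.set2conditions)` after calling `prepare_split`. Only `train` and `val` are required by default, matching a no_test split.

```python
def check_split(set2conditions, required=("train", "val")):
    """Hypothetical sanity check to run before building dataloaders.

    Returns a list of human-readable problems; an empty list means every
    required fold exists, is non-empty, and no two folds share a condition.
    """
    problems = []
    for fold in required:
        if fold not in set2conditions:
            problems.append(f"missing fold '{fold}'")
        elif len(set2conditions[fold]) == 0:
            problems.append(f"fold '{fold}' is empty")
    # Conditions should be disjoint across folds (pairwise check).
    folds = list(set2conditions)
    for i, a in enumerate(folds):
        for b in folds[i + 1:]:
            shared = set(set2conditions[a]) & set(set2conditions[b])
            if shared:
                problems.append(f"folds '{a}' and '{b}' share {sorted(shared)}")
    return problems
```

Returning a list rather than raising on the first issue reports every problem in one pass. This is convenient when a subset of conditions, as in the Dixit example, leaves a fold empty or missing.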
yhr91 commented 9 months ago

Sorry for the late response, there was a bug in the implementation that's now been fixed in the latest version (0.1.2). Let me know if you have any more questions. Thanks
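Since the fix landed in 0.1.2, it can help to confirm which version is installed before rerunning the examples above. This sketch assumes GEARS is installed from PyPI under the distribution name `cell-gears`. The version comparison is a simple numeric parse that ignores pre-release and local suffixes, not a full implementation of Python's version-ordering rules.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v):
    """Parse the leading numeric release, e.g. '0.1.2' -> (0, 1, 2).

    Stops at the first non-numeric suffix, so '0.1.2rc1' -> (0, 1, 2).
    """
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
        if m.end() != len(piece):
            break  # a suffix such as 'rc1' ends the numeric release part
    return tuple(parts)


def is_at_least(installed, required="0.1.2"):
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    # Pad with zeros so '0.2' compares like '0.2.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_gears_version(dist_name="cell-gears"):
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None
```

For example, `v = installed_gears_version()` followed by `v is not None and is_at_least(v)` reports whether the installed package is at least 0.1.2.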