theislab / chemCPA

Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
https://arxiv.org/abs/2204.13545
MIT License

Fold-change prediction on novel SMILES #109

michal-pikusa closed this issue 1 year ago

michal-pikusa commented 1 year ago

Hi. Thanks for sharing this exciting method!

I wanted to try it out by predicting fold-changes for a set of novel SMILES using the pretrained RDKit model, but I think I am stuck.

I have prepared the lincs_trapnell parquet file using your script in the embeddings folder, and I was able to manually load the relevant config and pretrained model with this piece of code:

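import json

import pandas as pd
from tqdm import tqdm

from utils import load_model  # assumed: the helper module used by the chemCPA notebooks

root_dir = '/myvolume/chemCPA'
model_hash = '27b401db1845eea26c102fb614df9c33'

print('Loading config json...')
with open(f'{root_dir}/config.json', 'r') as f:
    file_data = json.load(f)

# Pick the config whose hash matches the pretrained model
config = None
for _config in tqdm(file_data):
    if _config["config_hash"] == model_hash:
        config = _config["config"]
        config["config_hash"] = _config["config_hash"]
        break
if config is None:
    raise ValueError(f'No config with hash {model_hash} found in config.json')

# SMILES the embedding and model were built with
smiles_list = pd.read_csv(f'{root_dir}/embeddings/lincs_trapnell.smiles')['smiles'].tolist()

# Point the config at the local embedding and checkpoint directories
config["model"]["embedding"]["directory"] = f'{root_dir}/embeddings'
config["pretrained_model_path"] = f'{root_dir}/notebooks/chemCPA_models'

print('Loading model...')
model_pretrained_rdkit, embedding_pretrained_rdkit = load_model(config, smiles_list)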

Browsing through the notebooks, I found that I need to use this function:

drug_r2_pretrained_degs_rdkit, _ = compute_pred(
    model_pretrained_rdkit,
    datasets['ood'],
    genes_control=datasets['test_control'].genes,
    dosages=dosages,
    cell_lines=cell_lines,
    use_DEGs=True,
    verbose=False,
)
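compute_pred is called with use_DEGs=True, which by its name restricts scoring to differentially expressed genes. Below is a minimal sketch of that kind of score, assuming R² is computed between per-gene mean true and mean predicted expression over the DEG subset only. It is an illustration, not chemCPA's compute_pred; r2_on_degs and its arguments are hypothetical. Note that any such score needs measured responses, so it cannot be computed for novel SMILES that have no ground truth.

```python
from statistics import fmean
from typing import List, Sequence


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination between two equal-length sequences."""
    mean_true = fmean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot  # undefined if all true values are equal


def mean_expression(cells: Sequence[Sequence[float]]) -> List[float]:
    """Per-gene mean over cells, where each cell is a list of gene values."""
    return [fmean(gene_values) for gene_values in zip(*cells)]


def r2_on_degs(true_cells, pred_cells, gene_names, degs) -> float:
    """Hypothetical helper: R^2 of per-gene means, restricted to the DEGs."""
    deg_set = set(degs)
    keep = [i for i, gene in enumerate(gene_names) if gene in deg_set]
    mu_true = mean_expression(true_cells)
    mu_pred = mean_expression(pred_cells)
    return r2_score([mu_true[i] for i in keep], [mu_pred[i] for i in keep])
```

Restricting to DEGs matters because most genes barely respond to a drug, so an R² over all genes can look high even when the drug-specific changes are predicted poorly.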

The problem is that I don't have the dataset used in the notebooks, so I can't see what structure the molecule list in datasets['ood'] should have, nor what genes_control should be. I already have dosages and cell_lines from the notebook itself.

Could you please either provide a sample dataset so I can get familiar with these lists, or tell me what they need to be? I am not sure whether I should pass SMILES for datasets['ood'], or whether I have to embed them first using one of the scripts and pass the embeddings instead. I want to get a fairly simple prediction working before going deeper, which is why I'm not using seml but a bit of code hacked together from your scripts. I hope that makes sense, thanks!

MxMstrmn commented 1 year ago

Hi @michal-pikusa,

Thanks for reaching out! I am not 100% sure I understood the exact problem, but this might help:

This is how we load the adata from a config:

import scanpy as sc


def load_dataset(config):
    # Column names in adata.obs that hold the drug name and its SMILES
    perturbation_key = config["dataset"]["data_params"]["perturbation_key"]
    smiles_key = config["dataset"]["data_params"]["smiles_key"]
    dataset = sc.read(config["dataset"]["data_params"]["dataset_path"])
    key_dict = {
        "perturbation_key": perturbation_key,
        "smiles_key": smiles_key,
    }
    return dataset, key_dict

And the corresponding config['dataset']['data_params'] can look like this:

{'covariate_keys': 'cell_type',
 'dataset_path': 'path_to_adata.h5ad',
 'degs_key': 'rank_genes_groups_cov',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_key',
 'use_drugs_idx': True}
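Reading the key names, most of these entries appear to name columns in the AnnData's obs table (cell type, drug, SMILES, dose, combined category, split). Under that assumption, a minimal sketch of obs rows for a set of novel SMILES could enumerate every (cell type, drug, dose) combination and mark all of them as ood. build_obs_records is hypothetical, and the f"{cell}_{name}_{dose}" format for the combined category is an assumption, not taken from chemCPA.

```python
from itertools import product
from typing import Dict, List, Sequence

# Keys mirroring config['dataset']['data_params'] above (subset).
DATA_PARAMS = {
    "covariate_keys": "cell_type",
    "dose_key": "dose",
    "pert_category": "cov_drug_dose_name",
    "perturbation_key": "condition",
    "smiles_key": "SMILES",
    "split_key": "split_key",
}


def build_obs_records(
    smiles_by_name: Dict[str, str],
    cell_types: Sequence[str],
    doses: Sequence[float],
    params: Dict[str, str] = DATA_PARAMS,
) -> List[dict]:
    """Hypothetical helper: one obs row per (cell type, drug, dose), all in 'ood'."""
    records = []
    for cell, (name, smiles), dose in product(cell_types, smiles_by_name.items(), doses):
        records.append({
            params["covariate_keys"]: cell,
            params["perturbation_key"]: name,
            params["smiles_key"]: smiles,
            params["dose_key"]: dose,
            # Assumed label format for the combined covariate/drug/dose category
            params["pert_category"]: f"{cell}_{name}_{dose}",
            params["split_key"]: "ood",
        })
    return records
```

The records can be turned into a table with pandas.DataFrame(records) and used as the obs of an AnnData. The expression matrix itself and the DEG dictionary under degs_key are not covered by this sketch.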

The column named by split_key defines which observations are in ood. Our naming is somewhat unusual: each split column assigns every observation to train, test, or ood.
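To make that naming concrete: each value in the split column assigns one observation to train, test, or ood, so the splits are just groups of row indices. A minimal sketch of that grouping, assuming the column is available as a plain sequence of labels; split_indices is a hypothetical helper, not part of chemCPA.data.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def split_indices(split_column: Sequence[str]) -> Dict[str, List[int]]:
    """Hypothetical helper: group observation row indices by their split label."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(split_column):
        groups[label].append(idx)
    return dict(groups)
```

For example, split_indices(["train", "test", "ood", "train"]) groups rows 0 and 3 under train, row 1 under test and row 2 under ood.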

The torch dataset is then loaded as follows:

data_params = config['dataset']['data_params']
datasets = load_dataset_splits(**data_params, return_dataset=False)

where load_dataset_splits is imported with from chemCPA.data import load_dataset_splits. The function for computing predictions (compute_pred) operates on these torch datasets.

michal-pikusa commented 1 year ago

Thank you @MxMstrmn for your quick response. I'll try to be more precise, since my use case may be naive and my assumptions may be wrong to begin with.

What I would like to do is take a pretrained model from your repo and predict fold-changes on the L1000 genes for a completely novel set of SMILES. I only have the SMILES strings, nothing else, so I have no AnnData dataset to start from.

Is this possible? This is the part I struggle with: the notebooks show how to train or fine-tune the model on an existing dataset and use the ood split for predictions, but not how to predict on novel compounds only, hence my question. Looking at the code, I see that predictions are made on torch datasets, so the SMILES have to be embedded first. What I am really looking for is a clear path from loading a pretrained model, to embedding novel SMILES without any additional metadata, to making predictions on them.
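The path described above boils down to computing one embedding per novel SMILES and then referring to each drug by its row in that embedding table (the use_drugs_idx flag in the config suggests drugs are passed as indices). A minimal sketch under those assumptions: embed stands in for whatever featurizer produced the pretrained model's embeddings (presumably the RDKit script in the embeddings folder, with the same normalization) and is hypothetical here, as is build_drug_index.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def build_drug_index(
    smiles: Sequence[str],
    embed: Callable[[str], List[float]],
) -> Tuple[Dict[str, int], List[List[float]]]:
    """Hypothetical helper: one embedding row per unique SMILES, in first-seen order."""
    index: Dict[str, int] = {}
    table: List[List[float]] = []
    for s in smiles:
        if s not in index:
            index[s] = len(table)  # next free row in the embedding table
            table.append(embed(s))
    return index, table


# Usage with a toy featurizer (string length), NOT a real molecular embedding:
novel = ["CCO", "C", "CCO"]
index, table = build_drug_index(novel, lambda s: [float(len(s))])
drug_idx = [index[s] for s in novel]  # per-observation drug indices
```

Deduplication here is by raw string, so two different SMILES for the same molecule get separate rows; canonicalizing first (e.g. with RDKit's Chem.CanonSmiles) would merge them.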

If the code is not suitable for this use case, then I will just train a model from scratch, but I wanted to know first.

Thanks again for your help!

siboehm commented 1 year ago

I'm not familiar with fold-change prediction, but assuming it just takes our predicted gene expression as input, this should be possible. The datasets we work with store the SMILES as plain strings, and we call the embedding models here using only the SMILES string, so no metadata or concrete dataset is needed.

It's probably most useful for you to look at how we do the evaluation. Some experiments in the paper run evaluation on drugs that were never observed during training and obtain their predicted gene expression, which sounds like exactly your use case.
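Assuming, as suggested above, that a fold change is derived from the predicted gene expression, one simple option is the per-gene difference between mean predicted and mean control expression. This sketch assumes the expression values are already log-normalized, so a difference of means approximates a log fold change (for log1p data it is only an approximation, and the log base follows the data's normalization). log_fold_change is a hypothetical helper, not part of chemCPA.

```python
from statistics import fmean
from typing import List, Sequence


def log_fold_change(
    pred_cells: Sequence[Sequence[float]],
    control_cells: Sequence[Sequence[float]],
) -> List[float]:
    """Hypothetical helper: per-gene mean(pred) - mean(control) on log-scaled data."""
    pred_mean = [fmean(col) for col in zip(*pred_cells)]
    ctrl_mean = [fmean(col) for col in zip(*control_cells)]
    return [p - c for p, c in zip(pred_mean, ctrl_mean)]
```

Here each cell is a list of gene values in the same gene order (e.g. the L1000 genes), and the control cells would be untreated cells of the same cell line.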