jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Pocket representation #103

Closed jyaacoub closed 2 weeks ago

jyaacoub commented 4 months ago

Pocket-only representation

To make sure we don't have to build entirely separate datasets for the pocket representation, this implementation should just get the index positions of our binding pocket and then apply a mask to the original graph (similar to how it is done with the dropout_node function in pytorch_geometric).
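A minimal sketch of what that masking could look like (assuming a torch_geometric Data graph and a precomputed list of pocket index positions; mask_to_pocket is just an illustrative helper, not something in the repo):

import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph

def mask_to_pocket(g: Data, pocket_idx: list) -> Data:
    """Keep only pocket nodes and the edges between them."""
    mask = torch.zeros(g.num_nodes, dtype=torch.bool)
    mask[pocket_idx] = True  # pocket_idx comes from the KLIFS alignment step
    # subgraph() drops edges touching masked-out nodes and relabels the rest
    edge_index, _ = subgraph(mask, g.edge_index,
                             relabel_nodes=True, num_nodes=g.num_nodes)
    return Data(x=g.x[mask], edge_index=edge_index)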

Task list:

KLIFS Database

This is what KDBNet uses to get the binding pockets for davis and kiba ("pocket is 85 residues long").

The sequence returned by KLIFS is not contiguous and contains only the relevant pocket residues, so we need to align it against the full sequence to recover the index positions.

Once we have our list of index positions for the binding pocket amino acids, we can modify our existing graph by applying a mask (a sketch of the alignment step follows below).
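A rough sketch of that alignment step with Biopython (get_pocket_positions is a hypothetical helper, and the gap scores would need tuning):

from Bio import Align

def get_pocket_positions(full_seq: str, pocket_seq: str) -> list:
    """Align the KLIFS pocket to the full sequence and return residue indices."""
    aligner = Align.PairwiseAligner()
    aligner.mode = 'local'
    aligner.open_gap_score = -10    # discourage spurious gaps; tune as needed
    aligner.extend_gap_score = -0.5
    # KLIFS marks missing pocket residues with '-'; strip them before aligning
    aln = aligner.align(full_seq, pocket_seq.replace('-', ''))[0]
    # aligned[0] holds the matched (start, end) ranges on the full sequence
    return [i for start, end in aln.aligned[0] for i in range(start, end)]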

Getting pockets for Kiba

  1. Using the UniProt ID, we can get the pocket from the KLIFS database via the /kinase_ID API.
    1. For example: https://klifs.net/api/kinase_ID?kinase_name=O00141&species=HUMAN returns the matching kinase record (response screenshot omitted).
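A minimal query sketch with requests (the 'pocket' field name is based on the KLIFS response and worth double-checking against their docs):

import requests

resp = requests.get('https://klifs.net/api/kinase_ID',
                    params={'kinase_name': 'O00141', 'species': 'HUMAN'})
resp.raise_for_status()
kinase = resp.json()[0]  # the API returns a list of matching kinase records
print(kinase['kinase_ID'], kinase['pocket'])  # pocket = 85-residue sequence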

Getting pockets for Davis

Same as for kiba, but we use the raw gene name (we need to remove any mutation or phosphorylation information): ABL1(F317I)p -> ABL1

  1. For example: https://klifs.net/api/kinase_ID?kinase_name=ABL1&species=HUMAN returns the matching kinase record (response screenshot omitted).
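The cleanup itself can be a couple of regexes (clean_davis_name is a hypothetical helper):

import re

def clean_davis_name(name: str) -> str:
    name = re.sub(r'\(.*?\)', '', name)  # drop mutation info, e.g. '(F317I)'
    return re.sub(r'p$', '', name)       # drop trailing phosphorylation marker

assert clean_davis_name('ABL1(F317I)p') == 'ABL1'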

However, for mutated genes we must be careful with the sequence alignment and follow this procedure to get the right amino acid index positions:

  1. Revert the mutation
  2. Perform the alignment
  3. Extract the index positions

Then we just use these positions for our mask on the original (mutated) graph (a small sketch of step 1 follows below).
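Step 1 could look something like this (revert_mutation is a hypothetical helper; it assumes single point mutations in 'F317I'-style notation with 1-based positions):

import re

def revert_mutation(mut_seq: str, mutation: str) -> str:
    """Undo a point mutation like 'F317I' so the sequence matches wild type."""
    wt, pos, mt = re.match(r'([A-Z])(\d+)([A-Z])', mutation).groups()
    i = int(pos) - 1  # 1-based position -> 0-based index
    assert mut_seq[i] == mt, f'expected {mt} at position {pos}'
    return mut_seq[:i] + wt + mut_seq[i + 1:]

Since a point mutation does not change the sequence length, the index positions found on the reverted sequence carry over directly to the mutated graph.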

jyaacoub commented 1 month ago

Building the pocket dataset

Assumes that we have a normal dataset built already.

1. Get and mask pockets with KLIFS

This should be done first on the login node since it queries the KLIFS database for the sequences and caches them locally. Then we can use the skip_download arg to run this on the compute node for the rest of the datasets.

# building pocket datasets:
from src.utils.pocket_alignment import pocket_dataset_full

data_dir = '/cluster/home/t122995uhn/projects/data/'
db_type = ['kiba', 'davis']
db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary', 
           'nomsa_binary_gvp_binary',      'nomsa_aflow_gvp_binary']

for t in db_type:
    for f in db_feat:
        print(f'\n---{t}-{f}---\n')
        dataset_dir = f"{data_dir}/DavisKibaDataset/{t}/{f}/full"
        save_dir    = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full"

        pocket_dataset_full(
            dataset_dir=dataset_dir,
            pocket_dir=f"{data_dir}/{t}/",  # cached KLIFS pocket sequences
            save_dir=save_dir,
            skip_download=True  # sequences were already cached on the login node
        )

2. Resplit the database

import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba]
splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s
          for s in ['davis', 'kiba']]
print(splits)

#%%
for split, db in zip(splits, dbs):
    print('\n', split, db)
    create_datasets(db,
                    feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                    edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                    ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                    ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                    k_folds=5,
                    test_prots_csv=f'{split}/test.csv',
                    val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)])
                    # data_root=os.path.abspath('../data/test/')

3. Test inference

#%%
from src import cfg
from src.utils.loader import Loader

# db2 = Loader.load_dataset(cfg.DATA_OPT.davis, 
#                          cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
#                          path='/cluster/home/t122995uhn/projects/data/',
#                          subset="full")

db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis,
                              cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                              path='/cluster/home/t122995uhn/projects/data/v131',
                              training_fold=0,
                              batch_train=2)
b2 = next(iter(db2['train']))  # grab a single training batch

# %%
m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                      dropout=0.3480, output_dim=256)

#%%
m(b2['protein'], b2['ligand'])  # forward pass as a quick sanity check