jyaacoub / MutDTA

Improving the precision oncology pipeline by predicting binding affinity perturbations for a priori identified cancer driver genes.
1 stars 2 forks source link

Explore MISATO dataset #76

Closed jyaacoub closed 8 months ago

jyaacoub commented 8 months ago

There are some issues with the MISATO dataset that need to be resolved before I can readily use it (see https://github.com/t7morgen/misato-dataset/issues/7).

CODE

# %% Loading misato-dataset
import pickle
import pandas as pd

from misato_dataset import MolDataset, ProtDataset, GNNTransformMD, GNNTransformQM
from torch_geometric.loader import DataLoader
from torch_geometric import transforms as T

from src.feature_extraction.protein import Chain

from pathlib import Path

# Resolve the home directory explicitly: a literal '~' inside a path string is
# not expanded by open() (and presumably not by the h5py-backed datasets), so
# every path built from HOME would point at a non-existent './~/...' location.
HOME = Path.home()

misato_dir = f'{HOME}/projects/data/MISATO/'

MD_fp = f"{misato_dir}/MD.hdf5"
QM_fp = f"{misato_dir}/QM.hdf5"
normQM_fp = f"{HOME}/projects/misato-dataset/data/QM/h5_files/qm_norm.hdf5"

# File paths for the train, val and test splits, which were made by splitting on
# protein-sequence similarity. Each split needs its own index file.
train = f'{misato_dir}/train_MD.txt'
val = f'{misato_dir}/val_MD.txt'
test = f'{misato_dir}/test_MD.txt'

# %%
# Loader settings kept for the exploration below.
batch_size = 128
num_workers = 48
# PyG RandomJitter randomly translates node positions; it is handed to the QM
# dataset below as its post_transform.
transform = T.RandomJitter(0.25)

# QM (ligand) graphs for the training split; target_norm_file presumably
# supplies the target normalisation statistics (TODO confirm).
mol_train = MolDataset(
    QM_fp,
    train,
    target_norm_file=normQM_fp,
    transform=GNNTransformQM(),
    post_transform=transform,
)
# MD (protein) graphs for the test split, without any post-transform.
pro_test = ProtDataset(MD_fp, test, transform=GNNTransformMD())

#%% sequence info is in the protein h5f file...
p_id = '2G6P'.lower()
sample = pro_test.f[p_id.upper()]
# Integer residue codes as stored in the file; decoded with `itores` below.
seq = list(sample['atoms_residue'])
# NOTE(review): counts atoms whose integer type code is 6 and labels them CA;
# confirm via ito_atmtype[6] that this is the alpha-carbon type.
print("CA count:", list(sample['atoms_type']).count(6))

# The integer codes must be mapped back to strings using MISATO's map files.
# This must be an f-string so that HOME is actually substituted into the path.
p_dir = f"{HOME}/projects/misato-dataset/src/misato_dataset/processing/Maps/"
# `with` closes each pickle file as soon as it has been read.
with open(f"{p_dir}/atoms_residue_map.pickle", 'rb') as fh:
    itores = pickle.load(fh)
with open(f"{p_dir}/atoms_type_map.pickle", 'rb') as fh:
    ito_atmtype = pickle.load(fh)

# Decoding the residue codes could be done with:
#     print("".join(itores[i] for i in seq))

df = pd.read_csv(f'{HOME}/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv')
# Reference sequence for the same complex from the PDBbind-derived CSV:
#     seq = df[df.code == p_id].prot_seq.iloc[0]
#     print(len(seq), seq)

# Parse the full protein first, then its binding pocket, from the local PDBbind
# v2020 structures. After the loop, `c` holds the pocket chain.
for part in ('protein', 'pocket'):
    c = Chain(f"{HOME}/projects/data/v2020-other-PL/{p_id}/{p_id}_{part}.pdb")
    print(len(c), c)

#%%
# NOTE(review): 125 differs from batch_size (128) defined earlier; confirm which is intended.
mol_train_loader = DataLoader(mol_train, 125, shuffle=True, num_workers=0)
# The protein loader must wrap the MD protein dataset, not the QM ligand dataset.
pro_test_loader = DataLoader(pro_test, 1, shuffle=True, num_workers=0)

#%%
# Grab a single protein batch. GNNTransformMD batches carry their PDB codes in
# the `ids` attribute (there is no `id`).
p = next(iter(pro_test_loader))

print(p, p.ids)

#%%
# Peek at one QM ligand batch. It gets its own name so that the `val` split
# path defined earlier is not overwritten by a loop variable.
mol_batch = next(iter(mol_train_loader))
print(mol_batch)

# %%
# Instantiate the project's test model through the Loader helper.
from src.utils import config as cfg  # NOTE(review): cfg is not used in this snippet
from src.utils.loader import Loader

model = Loader.init_test_model() 

# %%
jyaacoub commented 8 months ago

They use the entire PDB file for the Amber MD simulation.

The graph data objects are created by the `GNNTransformMD` class, which uses `prot_graph_transform` to build them:

image

This results in a `Data` object that looks like this:

DataBatch(x=[3262, 11], edge_index=[2, 52466], edge_attr=[52466], y=[3262], pos=[3262, 3], ids=[1], batch=[3262], ptr=[2])

Here `x` holds the node features: 3262 is the number of atoms, and 11 is the length of the one-hot vector encoding the atom type. `edge_index` has the standard `[2, num_edges]` layout. `edge_attr` is a simple reciprocal Euclidean distance that acts as an edge weight: the farther apart two nodes are, the smaller the weight.

image image

Code:

import pickle
import pandas as pd

from misato_dataset import MolDataset, ProtDataset, GNNTransformMD, GNNTransformQM
from torch_geometric.loader import DataLoader
from torch_geometric import transforms as T

from src.data_prep.feature_extraction.protein import Chain
from pathlib import Path
import pickle

HOME = Path.home()

misato_dir = f'{HOME}/projects/data/MISATO/'
misato_dir_tiny = f"{HOME}/projects/misato-dataset/data"
map_dir = f"{HOME}/projects/misato-dataset/src/misato_dataset/processing/Maps/"


def _load_map(name):
    """Return the unpickled MISATO lookup map stored as ``name`` in ``map_dir``.

    The file handle is closed as soon as the map has been read. These pickles
    come from the local misato-dataset checkout; pickle can execute arbitrary
    code, so never point this at untrusted files.
    """
    with open(f"{map_dir}/{name}", 'rb') as fh:
        return pickle.load(fh)


# Lookup tables used to translate MISATO's integer codes back to names.
atm_name = _load_map("atoms_name_map_for_pdb.pickle")
i_to_type = _load_map("atoms_type_map.pickle")
i_to_res = _load_map("atoms_residue_map.pickle")

# Tiny MISATO MD subset and its split index files.
mdh5_file = f'{misato_dir_tiny}/MD/h5_files/tiny_md_out.hdf5'
train_idx = f"{misato_dir_tiny}/MD/splits/train_tinyMD.txt"
val_idx = f"{misato_dir_tiny}/MD/splits/val_tinyMD.txt"
test_idx = f"{misato_dir_tiny}/MD/splits/test_tinyMD.txt"

# Complex to inspect (swap in '2g6p' to look at the earlier example).
p_id = '10gs'

# %% LOAD PROTEIN DATA:
MD_fp = f"{misato_dir}/MD.hdf5"
train = f'{misato_dir}/train_MD.txt'


def _jittered_md_dataset(h5_path, split_file):
    """Build a ProtDataset of MD graphs with RandomJitter(0.05) as post-transform."""
    return ProtDataset(
        h5_path,
        idx_file=split_file,
        transform=GNNTransformMD(),
        post_transform=T.RandomJitter(0.05),
    )


# NOTE(review): despite the "test" names, both datasets index the train split.
pro_test_tiny = _jittered_md_dataset(mdh5_file, train_idx)
pro_test = _jittered_md_dataset(MD_fp, train)

# Raw entry in the dataset's underlying file (`.f`), keyed by upper-case PDB code.
sample = pro_test.f[p_id.upper()]

#%%
test_loader = DataLoader(pro_test_tiny, batch_size=1, num_workers=16)
# Take the first batch. next(iter(...)) states the intent directly, and it raises
# StopIteration on an empty loader instead of silently leaving the name unbound.
# Despite its name, `sample_loader` holds a batch, not a loader.
sample_loader = next(iter(test_loader))

#%%
# Decode the integer codes of the sample into residue names and atom-type strings.
seq = list(map(i_to_res.__getitem__, sample['atoms_residue']))
atom_type = list(map(i_to_type.__getitem__, sample['atoms_type']))

# Alpha carbons are counted via the atom-type label 'CX'.
n_ca = sum(1 for t in atom_type if t == 'CX')
print("CA count:", n_ca)

#%% GET REAL
# Reference structures from the local PDBbind v2020 set: full protein first,
# then the pocket. After the loop, `c` holds the pocket chain.
for part in ('protein', 'pocket'):
    c = Chain(f"{HOME}/projects/data/v2020-other-PL/{p_id}/{p_id}_{part}.pdb")
    print(len(c), c)

#%%
from prody import parsePDB

# NOTE(review): this rebinds p_id to '2g6p', overriding the '10gs' complex
# inspected above; confirm that is intended.
p_id = '2g6p'
pdb_path = f"{HOME}/projects/data/v2020-other-PL/{p_id}/{p_id}_protein.pdb"
# Parse only the alpha-carbon subset of the full protein structure.
pdb = parsePDB(pdb_path, subset='ca')
pdb
#%%