jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.
1 stars 2 forks source link

Use RING features #77

Closed jyaacoub closed 6 months ago

jyaacoub commented 8 months ago

For this I will need to use a different GNN that accepts edge_attr, such as TransformerConv below:

Code

# %%
import torch
from torch_geometric.nn import TransformerConv
from torch_geometric import data as geo_data
from src.models.ring_mod import Ring3DTA

#%%

# Run on the GPU when one is available. `torch.cuda.current_device()` raises
# on CPU-only machines, so fall back to the CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE setting edge_dim will require edge_attr to be passed in!
# NOTE out_channels*heads is the final output dim -> [N_nodes, out_channels*heads]
#      Unless concat is set to false (gets averaged instead) -> [N_nodes, out_channels]
pro_gnn1 = TransformerConv(in_channels=320, out_channels=512, heads=5,
                           concat=False,
                           dropout=0.2).to(device)

N_nodes = 100
prot_shape = (N_nodes, pro_gnn1.in_channels)
# torch.rand gives well-defined dummy features; torch.Tensor(*shape) only
# allocates uninitialised memory, which may contain NaN/inf.
prot = geo_data.Data(x=torch.rand(prot_shape),  # node feature matrix
                     # two directed edges: 0 -> 1 and 1 -> 0, stored as [2, num_edges]
                     edge_index=torch.LongTensor([[0, 1], [1, 0]]).transpose(1, 0),
                     y=torch.FloatTensor([1])).to(device)
#%%
# without edge attr this returns a tensor of shape (N_nodes, out_channels)
out = pro_gnn1(x=prot.x, edge_index=prot.edge_index)
out.shape

#%% with edge attr
# The second layer consumes the first layer's output, so its in_channels is
# pro_gnn1.out_channels (concat=False above keeps that width).
pro_gnn2 = TransformerConv(pro_gnn1.out_channels, 1024, edge_dim=4, dropout=0.2).to(device)
# edge_attr will be of shape (num_edges, edge_dim)
edge_attr = torch.rand((prot.edge_index.shape[1], pro_gnn2.edge_dim)).to(device)

out2 = pro_gnn2(x=out, edge_index=prot.edge_index, edge_attr=edge_attr)
out2

#%%
import matplotlib.pyplot as plt
# move to host memory and drop the autograd graph before handing to matplotlib
plt.matshow(out2.cpu().detach().numpy())

# remove axes
plt.xticks([])
plt.yticks([])
jyaacoub commented 8 months ago

Ring edge features are labeled as the following:

array(['HBOND:SC_MC', 'VDW:SC_SC', 'HBOND:MC_MC', 'VDW:MC_SC',
       'HBOND:MC_SC', 'PIPISTACK:SC_SC', 'VDW:SC_MC', 'PICATION:SC_SC',
       'HBOND:SC_SC', 'IONIC:SC_SC', 'VDW:MC_MC']

Note that SC_SC means sidechain-sidechain and MC_MC means mainchain-mainchain and so on for MC_SC and SC_MC

Translating this to the main features we want to use:

- H-Bond            -> HBOND
- Pi-Pi stack       -> PIPISTACK
- Pi-Cation         -> PICATION
- Ionic             -> IONIC
- Van der Waals     -> VDW
- Pi-H bond         -> ????

- Contact distance  -> From raw PDB file (see: 'simple' edge features)

And to get the frequencies across multiple models/conformations we can add the --md flag to the end of the ring command:

RING-MD

RING - Molecular Dynamics (MD) is a RING module that computes aggregated statistics over multi-state structure files like NMR structural ensembles or molecular dynamics snapshots (provided as different models in a PDB/mmCIF file).

The module can be executed by adding the --md flag:

ring -i 2m6z.cif --out_dir results --md

It generates the standard output files inside the results/ folder and creates a md subdirectory with four different types of data:

Where <type> is the type of interaction: HBOND, IAC, IONIC, PICATION, PIPISTACK, SSBOND, VDW


We are interested in <fileName>_gfreq_<type>

jyaacoub commented 7 months ago

Some plots for the 6 main features we want to extract:

Recall the features we want:

- H-Bond            -> HBOND
- Pi-Pi stack       -> PIPISTACK
- Pi-Cation         -> PICATION
- Ionic             -> IONIC
- Van der Waals     -> VDW
- Contact distance  -> From raw PDB file (see: 'simple' edge features)

image image

Code:

#%%
from glob import glob
from pathlib import Path
from src.utils.residue import Ring3Runner
import os
import logging
logging.getLogger().setLevel(logging.INFO)

# %%
# Single crystal structure and a set of AlphaFold conformations for EGFR.
pdb_7lqt = f'{Path.home()}/projects/data/misc/7LQT.pdb'

af_conf_dir = f'{Path.home()}/projects/data/misc/'
af_confs_EGFR = glob(f'{af_conf_dir}/EGFR*/EGFR_unrelaxed_rank_*.pdb')

#%%
from src.utils.residue import Chain
import matplotlib.pyplot as plt
import numpy as np
# cutoff applied to Chain.get_contact_map() values (presumably angstroms -- confirm)
thr = 8.0
# number of subplot columns; rows are derived from how many maps there are
n_cols = 3

for opt in [af_confs_EGFR, pdb_7lqt]:
    is_ensemble = opt is af_confs_EGFR

    # get distance contact map
    if is_ensemble:
        # fraction of conformations in which each residue pair is below thr
        chains = [Chain(p) for p in opt]
        M = np.array([c.get_contact_map() for c in chains]) < thr
        dist_cmap = np.sum(M, axis=0) / len(M)
    else:
        dist_cmap = Chain(opt).get_contact_map() < thr

    # ring3 edge attribute extraction
    # Note: this will create a "combined" pdb file in the same directory as the conformations
    input_pdb, files = Ring3Runner.run(opt, overwrite=True)
    seq_len = len(Chain(input_pdb))

    # One contact map per RING3 output file, in the order of `files`
    r3_cmaps = [Ring3Runner.build_cmap(fp, seq_len) for fp in files.values()]

    # plot every ring3 cmap plus the distance cmap, one panel each
    ks = list(files.keys()) + ['dist']
    cmaps = r3_cmaps + [dist_cmap]

    # Size the grid from the number of maps rather than hard-coding 2x3, so a
    # different number of RING3 outputs does not raise an IndexError.
    n_rows = -(-len(cmaps) // n_cols)  # ceil division
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows),
                            squeeze=False)
    panels = axs.ravel()

    for ax, title, cmap in zip(panels, ks, cmaps):
        ax.matshow(cmap)
        ax.set_title(title)

    # hide any panels left over when the maps do not fill the grid
    for ax in panels[len(cmaps):]:
        ax.axis('off')

    plt.suptitle(f'Ring3 Edge Attributes for {"EGFR" if is_ensemble else "7LQT"}')
    plt.show()
jyaacoub commented 7 months ago

Sample test works well:


# %%
from pathlib import Path
from glob import glob
from src.utils.residue import Chain
from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights
from src import config as cfg
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv
import torch

# Collect the EGFR conformation files and load the first one as the target.
af_conf_dir = f'{Path.home()}/projects/data/misc/'
conf_pattern = f'{af_conf_dir}/EGFR*/EGFR_unrelaxed_rank_*00.pdb'
af_confs_EGFR = glob(conf_pattern)

first_conf = af_confs_EGFR[0]
target = Chain(first_conf)
L = len(target)
# random stand-in node features
x = torch.rand(L, 2, dtype=torch.float32) # [N, feat_dim]

# %% get edge information
# Edges are the nonzero entries of the lower triangle of the thresholded
# contact map.
contact_map = target.get_contact_map()
dist_cmap = contact_map < 8.0
lower_tri = torch.tensor(np.tril(dist_cmap))
ei = lower_tri.nonzero().T # [2, E]

ring3_opt = cfg.EDGE_OPT.ring3.value
ea = get_target_edge_weights(
    '',
    target.sequence,
    edge_opt=ring3_opt,
    af_confs=af_confs_EGFR,
)
# using only the first cmap to determine which edge values to use
row_idx = ei[0]
col_idx = ei[1]
ea = torch.Tensor(ea[row_idx, col_idx, :]) # [E, 6]

sample_data = Data(x=x, edge_index=ei, edge_attr=ea)

#%%
# Single-head layer whose edge_dim matches the 6 edge attributes.
model = TransformerConv(
    in_channels=2,
    out_channels=2,
    heads=1,
    edge_dim=6, # 6 edge attributes
)
out = model(
    sample_data.x,
    sample_data.edge_index,
    sample_data.edge_attr,
)
out.shape
jyaacoub commented 6 months ago

For AlphaFlow conformations we can run RING3 as they are being generated:

#%%
from src.data_prep.datasets import BaseDataset
import pandas as pd
csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_ring3_original_binary/full/XY.csv"

# Table indexed by its first CSV column, reduced to the rows returned by
# BaseDataset.get_unique_prots; the loop below looks rows up by the AlphaFlow
# file stem and reads their `prot_id`.
df = pd.read_csv(csv_p, index_col=0)
df_unique = BaseDataset.get_unique_prots(df)

# %%
import os
from tqdm import tqdm
alphaflow_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/"
ln_dir =        "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pid_ln/"

os.makedirs(ln_dir, exist_ok=True)
# files are .pdb with 50 "models" in each
for fname in tqdm(os.listdir(alphaflow_dir)):
    if not fname.endswith('.pdb'):
        continue

    # the file stem is the key into df_unique; the link is named by prot_id
    code, _ = os.path.splitext(fname)
    pid = df_unique.loc[code].prot_id

    link_fp = os.path.join(ln_dir, f"{pid}.pdb")
    # os.symlink raises FileExistsError if the link is already there (e.g. on
    # a re-run, or when two codes share a prot_id), so keep the existing link.
    # lexists also catches dangling symlinks.
    if os.path.lexists(link_fp):
        continue

    os.symlink(os.path.join(alphaflow_dir, fname), link_fp)

# %% Run RING3 on finished conformations from AlphaFlow
from src.utils.residue import Ring3Runner

files = [os.path.join(ln_dir, f) for f in os.listdir(ln_dir)
         if f.endswith('.pdb')]

Ring3Runner.run_multiprocess(pdb_fps=files)