Closed by jyaacoub 1 month ago
Removing all exact PDB ID matches leaves us with 975 rows (967 if we drop NaN rows).
However, if we treat mutated and wild-type proteins as the same protein, we are left with 480 rows.
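
A rough sketch of the two filtering rules described above, in plain Python. The ID format (a `_<mutation>` suffix on mutants) and the example IDs are assumptions for illustration only; they are not necessarily how the dataset encodes mutants.

```python
def base_id(prot_id: str) -> str:
    """Strip a hypothetical '_<mutation>' suffix so mutant and wild type share one ID."""
    return prot_id.split("_", 1)[0].lower()


def drop_overlap(test_ids, train_ids, merge_mutants=False):
    """Return the test IDs that do not overlap with the training IDs.

    merge_mutants=False: only exact ID matches count as overlap.
    merge_mutants=True:  mutant and wild type are treated as the same protein.
    """
    key = base_id if merge_mutants else str.lower
    seen = {key(t) for t in train_ids}
    return [t for t in test_ids if key(t) not in seen]


# Made-up IDs, only to show the difference between the two rules
train = ["1abc", "2xyz_A10G"]
test = ["1abc", "1abc_M42L", "2xyz", "3def", "3def_K7R"]

exact = drop_overlap(test, train)
merged = drop_overlap(test, train, merge_mutants=True)
print(len(exact), exact)
print(len(merged), merged)
```

The second rule is stricter, which is why it removes many more rows than exact matching does.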
metric | with overlap | without overlap | p-val |
---|---|---|---|
cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
mse | 3.106 $\pm$ 0.121 | 3.467 $\pm$ 0.112 | 0.0598 |
mae | 1.380 $\pm$ 0.026 | 1.467 $\pm$ 0.014 | 0.0198 |
rmse | 1.761 $\pm$ 0.035 | 1.861 $\pm$ 0.030 | 0.062 |
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/13f95075-5409-405c-9d2e-94599a8cc0e3)
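
The thread does not say which test produced the p-val column. As one possible stand-in (not necessarily what was used here), an exact two-sided permutation test on the difference of means works for two small groups of per-fold metrics. The per-fold values below are made up.

```python
from itertools import combinations
from statistics import mean


def permutation_pvalue(a, b):
    """Exact two-sided permutation test on the difference of group means."""
    pooled = list(a) + list(b)
    n_a, n_b = len(a), len(b)
    observed = abs(mean(a) - mean(b))
    total = sum(pooled)
    hits = count = 0
    # Enumerate every way of assigning n_a of the pooled values to group A
    for idx in combinations(range(len(pooled)), n_a):
        sum_a = sum(pooled[i] for i in idx)
        diff = abs(sum_a / n_a - (total - sum_a) / n_b)
        count += 1
        # small tolerance so ties with the observed difference are counted
        if diff >= observed - 1e-12:
            hits += 1
    return hits / count


# Made-up per-fold values for two conditions (5 folds each)
with_overlap = [1, 2, 3, 4, 5]
without_overlap = [6, 7, 8, 9, 10]
p = permutation_pvalue(with_overlap, without_overlap)
print(f"p = {p:.4f}")
```

With 5 values per group there are only 252 possible splits, so the smallest achievable p-value is 2/252.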
metric | with overlap | without overlap | p-val |
---|---|---|---|
cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
mse | 0.939 $\pm$ 0.046 | 0.947 $\pm$ 0.034 | 0.9008 |
mae | 0.758 $\pm$ 0.020 | 0.754 $\pm$ 0.009 | 0.8436 |
rmse | 0.968 $\pm$ 0.024 | 0.972 $\pm$ 0.018 | 0.8879 |
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/d1b3a145-f259-4386-818a-6fb7562c7b7a)
Same analysis, but looking at $\Delta pkd$ this time:
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/e06bc2c4-84e0-4fc2-af09-855a366ecc9d)
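
For reference, a minimal sketch of what a $\Delta pkd$ computation could look like: each mutant's pkd minus the pkd of its wild-type reference. This is not the repo's `get_dpkd`; the row layout, the sign convention, and the example values are all assumptions.

```python
def delta_pkd(rows):
    """Map (protein, mutation) -> pkd(mutant) - pkd(wild type).

    rows: dicts with keys 'prot', 'mutation' (None for wild type) and 'pkd'.
    Mutants without a wild-type reference are skipped.
    """
    wt = {r["prot"]: r["pkd"] for r in rows if r["mutation"] is None}
    out = {}
    for r in rows:
        if r["mutation"] is not None and r["prot"] in wt:
            out[(r["prot"], r["mutation"])] = round(r["pkd"] - wt[r["prot"]], 3)
    return out


# Made-up rows: one wild type with two mutants, plus a mutant with no reference
rows = [
    {"prot": "1abc", "mutation": None, "pkd": 7.0},
    {"prot": "1abc", "mutation": "M42L", "pkd": 6.2},
    {"prot": "1abc", "mutation": "K7R", "pkd": 7.5},
    {"prot": "9zzz", "mutation": "A1G", "pkd": 5.0},
]
dpkd = delta_pkd(rows)
print(dpkd)
```

The same function can be applied to the predicted and the true pkd columns separately, and the metrics are then computed between the two sets of deltas.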
metric | With Overlap | Without Overlap | Significance |
---|---|---|---|
pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
mse | 1.505 $\pm$ 0.009 | 1.303 $\pm$ 0.006 | * |
mae | 0.905 $\pm$ 0.003 | 0.847 $\pm$ 0.002 | * |
rmse | 1.227 $\pm$ 0.004 | 1.141 $\pm$ 0.003 | * |
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/7b0d8675-cf9b-4504-a1fa-6e22f2896c91)
metric | With Overlap | Without Overlap | Significance |
---|---|---|---|
pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
mse | 1.649 $\pm$ 0.053 | 1.927 $\pm$ 0.158 | * |
mae | 0.899 $\pm$ 0.029 | 1.014 $\pm$ 0.023 | * |
rmse | 1.284 $\pm$ 0.021 | 1.387 $\pm$ 0.057 | * |
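
The $\pm$ values summarize the spread over the 5 models. The thread does not say whether they are standard deviations or standard errors. The copy of this table further down (the one with p-values) reports $\pm$ values smaller by roughly a factor of $\sqrt{5}$, which would be consistent with std here and SEM there, but that is only an inference. A small sketch of both, on made-up per-fold numbers:

```python
from math import sqrt
from statistics import mean, stdev


def summarize(values):
    """Return (mean, sample std, standard error of the mean)."""
    m = mean(values)
    sd = stdev(values)            # sample std (n - 1 denominator), an assumption
    sem = sd / sqrt(len(values))  # standard error of the mean
    return m, sd, sem


# Made-up pcorr values for 5 folds
fold_pcorr = [0.15, 0.20, 0.17, 0.21, 0.16]
m, sd, sem = summarize(fold_pcorr)
print(f"{m:.3f} +/- {sd:.3f} (std), +/- {sem:.3f} (sem)")
```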

metric | 1 mutation | 2+ mutations | Sig |
---|---|---|---|
pcorr | 0.076 $\pm$ 0.019 | 0.336 $\pm$ 0.047 | * |
scorr | 0.053 $\pm$ 0.015 | 0.252 $\pm$ 0.043 | * |
mse | 1.848 $\pm$ 0.038 | 1.328 $\pm$ 0.095 | * |
mae | 0.961 $\pm$ 0.031 | 0.833 $\pm$ 0.028 | * |
rmse | 1.359 $\pm$ 0.014 | 1.152 $\pm$ 0.041 | * |

metric | 1 mutation | 2 mutations | 3+ mutations |
---|---|---|---|
pcorr | 0.076 $\pm$ 0.019 | 0.207 $\pm$ 0.055 | 0.509 $\pm$ 0.070 |
scorr | 0.053 $\pm$ 0.015 | 0.131 $\pm$ 0.038 | 0.496 $\pm$ 0.078 |
mse | 1.848 $\pm$ 0.038 | 1.586 $\pm$ 0.110 | 0.982 $\pm$ 0.140 |
mae | 0.961 $\pm$ 0.031 | 0.964 $\pm$ 0.016 | 0.732 $\pm$ 0.022 |
rmse | 1.359 $\pm$ 0.014 | 1.259 $\pm$ 0.043 | 0.989 $\pm$ 0.070 |
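
A sketch of how rows could be bucketed by mutation count for the two tables above. The `/`-separated mutation string is a hypothetical format chosen for the example; the actual column layout in the Platinum file may differ.

```python
from collections import Counter


def n_mut_bucket(mutation: str) -> str:
    """Bucket a '/'-separated mutation string into '1', '2' or '3+'."""
    n = len([m for m in mutation.split("/") if m.strip()])
    if n == 0:
        raise ValueError("expected at least one mutation")
    return str(n) if n < 3 else "3+"


# Made-up mutation strings
muts = ["M42L", "M42L/K7R", "A10G/K7R/L99F", "A10G/K7R/L99F/T5S"]
buckets = [n_mut_bucket(m) for m in muts]
print(Counter(buckets))
```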

metric | mutation in pocket | mutation NOT in pocket | Sig |
---|---|---|---|
pcorr | 0.180 $\pm$ 0.029 | 0.110 $\pm$ 0.056 | * |
scorr | 0.108 $\pm$ 0.022 | 0.050 $\pm$ 0.102 | |
mse | 1.640 $\pm$ 0.057 | 1.781 $\pm$ 0.113 | * |
mae | 0.904 $\pm$ 0.021 | 0.981 $\pm$ 0.018 | * |
rmse | 1.280 $\pm$ 0.022 | 1.334 $\pm$ 0.042 | * |
```python
# %%
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src import config as cfg
from src.analysis.figures import get_dpkd

NORMALIZE = True

dfr = pd.read_csv(f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv", index_col=0)
df = pd.read_csv("/cluster/home/t122995uhn/projects/data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0).dropna()
# get_in_binding is the helper whose (truncated) definition appears later in this thread
df = get_in_binding(df, dfr=dfr)

fig = plt.figure(figsize=(14,5))
ax = fig.subplots(1,1)

# must include 0 in both cases since they are the wildtype reference
true_dpkd1 = get_dpkd(df.query('(pocket == 0) | (pocket == 2)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd1, kde=True, ax=ax, alpha=0.6, color='orange', label='not in pocket', stat='proportion')
true_dpkd1 = get_dpkd(df.query('(pocket == 0) | (pocket == 1)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd1, kde=True, ax=ax, alpha=0.6, color=None, label='in pocket', stat='proportion')

ax.set_title(f"{'Normalized 'if NORMALIZE else ''}TRUE Δpkd distribution")
ax.set_xlabel('Δpkd')
ax.legend()
```
subset | counts |
---|---|
wt | 981 |
near lig | 577 |
not near lig | 372 |

metric | mutation near lig (<4 Å) | mutation not near lig (>4 Å) | p-val |
---|---|---|---|
pcorr | 0.164 $\pm$ 0.012 | 0.198 $\pm$ 0.017 | 0.1405 |
scorr | 0.104 $\pm$ 0.009 | 0.079 $\pm$ 0.015 | 0.1844 |
mse | 1.672 $\pm$ 0.023 | 1.604 $\pm$ 0.034 | 0.1405 |
mae | 0.912 $\pm$ 0.010 | 0.939 $\pm$ 0.005 | 0.0408 |
rmse | 1.293 $\pm$ 0.009 | 1.266 $\pm$ 0.013 | 0.1381 |
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/9e82d7f7-f3f2-4215-8dde-2de126c7f2c3)
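
A sketch of the near/not-near split by distance to the ligand, using the 4 Å threshold and the 0/1/2 encoding that the plotting snippet queries on (0 = wild type, 1 = in pocket / near, 2 = not). Treating a missing distance as wild type, and putting a distance of exactly 4 Å in the "not near" group, are assumptions; the distances are made up.

```python
from collections import Counter

THRES = 4.0  # Å, same threshold as `thres` in the snippet below


def near_ligand_label(distance_to_lig, thres=THRES):
    """0 = wild type (no mutation distance), 1 = near ligand, 2 = not near ligand."""
    if distance_to_lig is None:
        return 0
    return 1 if distance_to_lig < thres else 2


# Made-up distances; None stands for a wild-type row
distances = [None, 2.5, 3.9, 4.0, 11.2]
labels = [near_ligand_label(d) for d in distances]
counts = Counter(labels)
print(labels, dict(counts))
```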
```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut, tbl_dpkd_metrics_in_binding, predictive_performance, tbl_stratified_dpkd_metrics
from src.analysis.metrics import get_metrics
from src import config as cfg
import pandas as pd
#%%
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"
NORMALIZE = True
n_models=5
verbose=True
plot=True
dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)
#%%
import seaborn as sns
sns.histplot(dfr['mut.distance_to_lig'])
#%%
# add in_binding info to df
thres = 4
def get_in_binding(df, dfr):
    """
    df is the predicted csv with index as
    """
    # NOTE: the rest of this function is cut off in the original comment
```
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/48acc8ff-e718-4d56-afa4-4b9af8b2d138) ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/7b8e0fa3-fce9-41ca-84ce-5f47d7cbe9d3)
![image](https://github.com/jyaacoub/MutDTA/assets/50300488/35c9b63a-b6b8-4cda-882a-89bb25a85288) ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/8e20e21e-4dec-452d-9304-4cc4cecce72b)
```python
# %%
import torch, os
import pandas as pd

from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.train_test.training import test
from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

INFERENCE = False
VERBOSE = True

out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)
cp_dir = cfg.CHECKPOINT_SAVE_DIR
RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"

#%% load up model:
for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
                                                  CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
                                                  n_epochs=2000, fold=fold,
                                                  ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
    print('\n\n'+ '## ' + KEY)
    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'
    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"
    if CONFIG['dataset'] in ['kiba', 'davis']:
        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
    else:
        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"

    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"

    if not os.path.exists(OUT_PLT(0)) or INFERENCE:
        print('running inference!')
        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
                                  pro_edge=CONFIG["edge_opt"],**CONFIG['architecture_kwargs'])

        # load up platinum test db
        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                          pro_feature = CONFIG['feature_opt'],
                                          edge_opt = CONFIG['edge_opt'],
                                          ligand_feature = CONFIG['lig_feat_opt'],
                                          ligand_edge = CONFIG['lig_edge_opt'],
                                          datasets=['test'])

        for i in range(5):
            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
            model.to(device)
            model.eval()

            loss, pred, actual = test(model, loaders['test'], device, verbose=True)

            # saving as csv with columns code, pred, actual
            # get codes from test loader
            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
            df.index.name = 'code'
            df.to_csv(OUT_PLT(i))

    # run platinum eval:
    print('\n### 1. predictive performance')
    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n### 2 Mutation impact analysis')
    # backslash escaped so that "\D" is not treated as an (invalid) escape sequence
    print('\n#### 2.1 $\\Delta pkd$ predictive performance')
    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)

    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)
```
metric | mean predictive performance |
---|---|
cindex | 0.472 $\pm$ 0.026 |
pcorr | -0.062 $\pm$ 0.050 |
scorr | -0.082 $\pm$ 0.078 |
mse | 2.123 $\pm$ 0.099 |
mae | 1.122 $\pm$ 0.049 |
rmse | 1.455 $\pm$ 0.034 |

metric | with overlap | without overlap | p-val |
---|---|---|---|
pcorr | 0.005 $\pm$ 0.012 | 0.005 $\pm$ 0.012 | 1 |
scorr | 0.003 $\pm$ 0.025 | 0.003 $\pm$ 0.025 | 1 |
mse | 1.990 $\pm$ 0.025 | 1.990 $\pm$ 0.025 | 1 |
mae | 0.980 $\pm$ 0.013 | 0.980 $\pm$ 0.013 | 1 |
rmse | 1.411 $\pm$ 0.009 | 1.411 $\pm$ 0.009 | 1 |

subset | counts |
---|---|
wt | 981 |
pckt | 708 |
not pckt | 241 |

metric | mutation in pocket | mutation NOT in pocket | p-val |
---|---|---|---|
pcorr | 0.004 $\pm$ 0.015 | 0.025 $\pm$ 0.032 | 0.5777 |
scorr | -0.009 $\pm$ 0.026 | 0.067 $\pm$ 0.071 | 0.3433 |
mse | 1.992 $\pm$ 0.030 | 1.950 $\pm$ 0.065 | 0.5777 |
mae | 1.000 $\pm$ 0.009 | 0.981 $\pm$ 0.021 | 0.4374 |
rmse | 1.411 $\pm$ 0.011 | 1.396 $\pm$ 0.023 | 0.5578 |

metric | mean predictive performance |
---|---|
cindex | 0.469 $\pm$ 0.007 |
pcorr | -0.058 $\pm$ 0.020 |
scorr | -0.095 $\pm$ 0.021 |
mse | 2.114 $\pm$ 0.040 |
mae | 1.148 $\pm$ 0.017 |
rmse | 1.454 $\pm$ 0.014 |

metric | with overlap | without overlap | p-val |
---|---|---|---|
pcorr | 0.008 $\pm$ 0.014 | 0.008 $\pm$ 0.014 | 1 |
scorr | 0.023 $\pm$ 0.011 | 0.023 $\pm$ 0.011 | 1 |
mse | 1.983 $\pm$ 0.029 | 1.983 $\pm$ 0.029 | 1 |
mae | 0.974 $\pm$ 0.011 | 0.974 $\pm$ 0.011 | 1 |
rmse | 1.408 $\pm$ 0.010 | 1.408 $\pm$ 0.010 | 1 |

subset | counts |
---|---|
wt | 981 |
pckt | 725 |
not pckt | 256 |

metric | mutation in pocket | mutation NOT in pocket | p-val |
---|---|---|---|
pcorr | 0.012 $\pm$ 0.012 | -0.008 $\pm$ 0.052 | 0.7244 |
scorr | 0.014 $\pm$ 0.019 | 0.025 $\pm$ 0.024 | 0.7289 |
mse | 1.977 $\pm$ 0.024 | 2.016 $\pm$ 0.105 | 0.7244 |
mae | 0.985 $\pm$ 0.014 | 0.979 $\pm$ 0.023 | 0.8043 |
rmse | 1.406 $\pm$ 0.009 | 1.418 $\pm$ 0.037 | 0.7599 |

metric | mean predictive performance |
---|---|
cindex | 0.446 $\pm$ 0.018 |
pcorr | -0.127 $\pm$ 0.059 |
scorr | -0.172 $\pm$ 0.051 |
mse | 2.253 $\pm$ 0.117 |
mae | 1.236 $\pm$ 0.032 |
rmse | 1.499 $\pm$ 0.039 |

metric | with overlap | without overlap | p-val |
---|---|---|---|
pcorr | -0.007 $\pm$ 0.010 | -0.007 $\pm$ 0.010 | 1 |
scorr | -0.019 $\pm$ 0.015 | -0.019 $\pm$ 0.015 | 1 |
mse | 2.015 $\pm$ 0.019 | 2.015 $\pm$ 0.019 | 1 |
mae | 0.922 $\pm$ 0.024 | 0.922 $\pm$ 0.024 | 1 |
rmse | 1.419 $\pm$ 0.007 | 1.419 $\pm$ 0.007 | 1 |

subset | counts |
---|---|
wt | 981 |
pckt | 708 |
not pckt | 241 |

metric | mutation in pocket | mutation NOT in pocket | p-val |
---|---|---|---|
pcorr | -0.013 $\pm$ 0.009 | 0.031 $\pm$ 0.024 | 0.117 |
scorr | -0.039 $\pm$ 0.018 | 0.063 $\pm$ 0.030 | 0.0208 |
mse | 2.027 $\pm$ 0.018 | 1.937 $\pm$ 0.048 | 0.117 |
mae | 0.922 $\pm$ 0.023 | 0.951 $\pm$ 0.027 | 0.4358 |
rmse | 1.424 $\pm$ 0.006 | 1.391 $\pm$ 0.017 | 0.1167 |

metric | mean predictive performance |
---|---|
cindex | 0.641 $\pm$ 0.011 |
pcorr | 0.415 $\pm$ 0.028 |
scorr | 0.411 $\pm$ 0.039 |
mse | 0.947 $\pm$ 0.034 |
mae | 0.754 $\pm$ 0.009 |
rmse | 0.972 $\pm$ 0.018 |

metric | with overlap | without overlap | p-val |
---|---|---|---|
pcorr | 0.176 $\pm$ 0.012 | 0.037 $\pm$ 0.035 | 0.0058 |
scorr | 0.099 $\pm$ 0.009 | 0.046 $\pm$ 0.027 | 0.0974 |
mse | 1.649 $\pm$ 0.024 | 1.927 $\pm$ 0.071 | 0.0058 |
mae | 0.899 $\pm$ 0.013 | 1.014 $\pm$ 0.010 | 0.0001 |
rmse | 1.284 $\pm$ 0.009 | 1.387 $\pm$ 0.025 | 0.005 |

subset | counts |
---|---|
wt | 981 |
pckt | 708 |
not pckt | 241 |

metric | mutation in pocket | mutation NOT in pocket | p-val |
---|---|---|---|
pcorr | 0.180 $\pm$ 0.013 | 0.110 $\pm$ 0.025 | 0.0374 |
scorr | 0.108 $\pm$ 0.010 | 0.050 $\pm$ 0.046 | 0.2533 |
mse | 1.640 $\pm$ 0.026 | 1.781 $\pm$ 0.050 | 0.0374 |
mae | 0.904 $\pm$ 0.010 | 0.981 $\pm$ 0.008 | 0.0003 |
rmse | 1.280 $\pm$ 0.010 | 1.334 $\pm$ 0.019 | 0.0365 |

GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE

```python
# %%
import os
import torch

from src import cfg
from src.utils.loader import Loader

# same device selection as in the evaluation script above
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Loader.init_model(cfg.MODEL_OPT.GVPL, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, dropout=0.02414, output_dim=256)

# %%
cp_dir = "/cluster/home/t122995uhn/projects/MutDTA/results/model_checkpoints/ours"
MODEL_KEY = lambda fold: f"GVPLM_PDBbind{fold}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE"
cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)

# load each of the 5 fold checkpoints in turn
for i in range(5):
    model.safe_load_state_dict(torch.load(cp(i), map_location=device))
    model.to(device)
    model.eval()
```