jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Platinum analysis #94

Open jyaacoub opened 2 months ago

jyaacoub commented 2 months ago
  1. Build dataset
    1. [x] Copy all aflow platinum conf files from #narval to #h4h ✅ 2024-04-29
    2. [x] Init new platinum dataset with new confs ✅ 2024-04-29
      ```python
      create_datasets(cfg.DATA_OPT.platinum,
                      cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                      ligand_features=cfg.LIG_FEAT_OPT.gvp,
                      ligand_edges=cfg.LIG_EDGE_OPT.binary,
                      k_folds=None, train_split=0, val_split=0)
      ```
  2. Run inference
    1. [x] Find trained weights for the GVPL-aflow model on PDBbind (all 5 folds, one per split, to use as an ensemble) ✅ 2024-04-29
      • Model key is GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE
      • They are all on #h4h
    2. [x] Save predictions of each model to a csv file ✅ 2024-04-29
      
      ```python
      import os

      import pandas as pd
      import torch

      from src import cfg
      from src.utils.loader import Loader
      from src.train_test.training import test

      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

      loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                        cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                                        ligand_feature=cfg.LIG_FEAT_OPT.gvp, ligand_edge=cfg.LIG_EDGE_OPT.binary,
                                        datasets=['test'])

      # %%
      model = Loader.init_model(cfg.MODEL_OPT.GVPL, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                                dropout=0.02414, output_dim=256)

      # %%
      cp_dir = "/cluster/home/t122995uhn/projects/MutDTA/results/model_checkpoints/ours"
      MODEL_KEY = lambda fold: f"GVPLM_PDBbind{fold}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE"
      cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

      out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
      os.makedirs(out_dir, exist_ok=True)

      # run each of the 5 fold models on the Platinum test set
      for i in range(5):
          model.safe_load_state_dict(torch.load(cp(i), map_location=device))
          model.to(device)
          model.eval()

          loss, pred, actual = test(model, loaders['test'], device, verbose=True)

          # saving as csv with columns code, pred, actual
          # get codes from test loader
          codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
          df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
          df.index.name = 'code'
          df.to_csv(f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv')
      ```

  3. Analysis
    1. [x] Check for overlap in PDB IDs with the training data and remove it
        - provide a plot of performance with and without the overlapping entries?
    2. How can I analyze the model's ability to detect dangerous mutations?
        1. [x] Standard predictive performance analysis with cindex and MSE scores
            - Statistical t-test to determine how different the predicted and experimental distributions are from each other
        2. [x] Mutation impact analysis:
            - calculate $\Delta pkd$ (the change in binding affinity between mutant and wildtype for each protein-ligand pair), then correlate the predicted $\Delta pkd$ with the experimental $\Delta pkd$
                - higher correlation indicates that the **model effectively captures the impact of mutations**
            - Stratify scores by the number of mutations (the model may struggle with larger numbers of mutations)
        3. [x] Identifying "significant" mutations
            - Based on the distribution of $\Delta pkd$ scores, classify all mutations above **1 or 2 standard deviations as "significant"**?
            - Then build a confusion matrix to analyze the model's true positive and true negative rates in identifying these mutations.
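The inference plan calls for using the 5 fold models as an ensemble, while the inference loop writes one CSV per fold. A minimal sketch of averaging them, assuming each CSV is indexed by `code` with `pred` and `actual` columns as written by that loop; `ensemble_predictions` is a hypothetical helper, and taking the plain mean across folds is an assumption rather than something the thread specifies:

```python
import pandas as pd


def ensemble_predictions(fold_dfs):
    """Average per-fold predictions into a single ensemble prediction.

    fold_dfs: list of DataFrames indexed by 'code' with 'pred' and 'actual'
    columns (the layout of the *_PLATINUM.csv files). Hypothetical helper.
    """
    # Align all folds on the shared 'code' index; dict keys become column labels
    preds = pd.concat({i: d['pred'] for i, d in enumerate(fold_dfs)}, axis=1)
    return pd.DataFrame({
        'pred': preds.mean(axis=1),      # ensemble mean across folds
        'pred_std': preds.std(axis=1),   # spread across folds (sample std)
        'actual': fold_dfs[0]['actual'], # ground truth is identical in every fold
    })
```

For example, reusing `out_dir` and `MODEL_KEY` from the inference code: `ensemble_predictions([pd.read_csv(f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv', index_col=0) for i in range(5)])`. The `pred_std` column gives a rough per-row measure of fold disagreement.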
jyaacoub commented 2 months ago

There is, unfortunately, a large overlap between the PDBbind training data and the Platinum dataset.

Removing all rows whose PDB ID appears in the PDBbind training (train + val) data leaves 975 rows (967 after dropping NaNs).

However, if we treat mutated and wildtype proteins as the same protein, only 480 rows remain.

code

```python
import pandas as pd

df = pd.read_csv("results/model_media/test_set_pred/GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv", index_col=0)

# training set codes:
data_p = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
df_t = pd.Index.append(pd.read_csv(data_p('train'), index_col=0).index,
                       pd.read_csv(data_p('val'), index_col=0).index)
df_t = df_t.str.upper()
df['pdb'] = df['prot_id'].str.split('_').str[0]

#%% remove training codes from df
# dont remove wt prots:
# df = df[~(df['pdb'].isin(df_t) & df.index.str.contains('_mt'))]
print(df)
# remove all
df = df[~(df['pdb'].isin(df_t))]
print(df)

# %% treat mutated and wt proteins as the same
wt_df = df[df.index.str.contains("_wt")]
mt_df = df[df.index.str.contains("_mt")]

missing_wt = delta_pkds = 0
for m in mt_df.index:
    i_wt = m.split('_')[0] + '_wt'
    if i_wt not in wt_df.index:
        missing_wt += 1
    else:
        delta_pkds += 1

print("missing wt:", missing_wt)
print("delta_pkds:", delta_pkds)
```
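The snippet above counts mutants with and without a wildtype partner but does not itself apply a pair-level filter. One way to treat a wildtype and its mutants as a single protein is to drop the whole wt/mt pair whenever any member's PDB ID is in the training (train + val) set. This is an interpretation, not necessarily the filter behind the 480-row count, and `drop_overlap_by_pair` is a hypothetical helper; it assumes codes of the form `<raw_idx>_wt` / `<raw_idx>_mt` and a `pdb` column as built above:

```python
import pandas as pd


def drop_overlap_by_pair(df, train_pdbs):
    """Drop every wt/mt pair where any row's PDB ID was seen during training.

    df: indexed by '<raw_idx>_wt' / '<raw_idx>_mt' codes, with a 'pdb' column.
    train_pdbs: upper-case PDB IDs from the training and validation splits.
    """
    raw_idx = df.index.str.split('_').str[0]
    hit = df['pdb'].isin(set(train_pdbs)).to_numpy()
    # every raw_idx with at least one overlapping row is treated as overlapping
    bad = set(raw_idx[hit])
    return df[~raw_idx.isin(bad)]
```

With the variables above this would be called as `drop_overlap_by_pair(df, df_t)`.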
jyaacoub commented 2 months ago

1. predictive performance

raw:

| | with overlap | without overlap | p-val |
|---|---|---|---|
| cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
| pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
| scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
| mse | 3.106 $\pm$ 0.121 | 3.467 $\pm$ 0.112 | 0.0598 |
| mae | 1.380 $\pm$ 0.026 | 1.467 $\pm$ 0.014 | 0.0198 |
| rmse | 1.761 $\pm$ 0.035 | 1.861 $\pm$ 0.030 | 0.062 |

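The p-val column compares the per-model scores of the 5 fold models. A minimal sketch of how one table row could be produced, assuming ± is the standard error over the 5 models and the p-value comes from `scipy.stats.ttest_ind` (which the section 3 code imports); `summarize_metric` is a hypothetical helper, not part of the repo. The MSE row is consistent with this reading: (3.467 - 3.106) / sqrt(0.121² + 0.112²) ≈ 2.19 with 8 degrees of freedom gives p ≈ 0.06.

```python
import numpy as np
from scipy.stats import ttest_ind


def summarize_metric(scores_a, scores_b):
    """Format per-model scores of two groups as 'mean $\\pm$ SE' and t-test them."""
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)

    def fmt(x):
        se = x.std(ddof=1) / np.sqrt(len(x))  # standard error of the mean
        return f"{x.mean():.3f} $\\pm$ {se:.3f}"

    # independent two-sample t-test (Student's, equal variances assumed)
    p = ttest_ind(a, b).pvalue
    return fmt(a), fmt(b), p
```

Identical groups give p = 1, which is what the Davis tables further down show for their identical with/without-overlap columns.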
distribution

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/13f95075-5409-405c-9d2e-94599a8cc0e3)

z-normalized:

| | with overlap | without overlap | p-val |
|---|---|---|---|
| cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
| pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
| scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
| mse | 0.939 $\pm$ 0.046 | 0.947 $\pm$ 0.034 | 0.9008 |
| mae | 0.758 $\pm$ 0.020 | 0.754 $\pm$ 0.009 | 0.8436 |
| rmse | 0.968 $\pm$ 0.024 | 0.972 $\pm$ 0.018 | 0.8879 |

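Why do cindex, pcorr and scorr stay the same under z-normalization while the error metrics change? Rank-based and linear-correlation metrics are invariant to a positive affine rescaling; MSE is not. For z-scored vectors (using the sample std, as pandas `.std()` does) the MSE reduces to 2(n-1)/n · (1 - r), so for large n the with-overlap z-normalized MSE should be close to 2(1 - 0.530) = 0.94, which matches the reported 0.939. In `predictive_performance` the normalization statistics are computed on the full set before the overlap filter, so the identity does not hold exactly for the without-overlap column. A small check with made-up numbers:

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr


def znorm(x):
    """Z-score with the sample std (ddof=1), matching pandas' default .std()."""
    x = np.asarray(x, dtype=float)
    return (x - x.mean()) / x.std(ddof=1)


# made-up pkd values, only for illustration
actual = np.array([5.0, 6.2, 7.1, 8.4, 9.0])
pred = np.array([5.5, 6.0, 7.9, 7.7, 8.8])
n = len(actual)

r_raw = pearsonr(actual, pred)[0]
r_z = pearsonr(znorm(actual), znorm(pred))[0]
rho_raw = spearmanr(actual, pred)[0]
rho_z = spearmanr(znorm(actual), znorm(pred))[0]

mse_raw = np.mean((actual - pred) ** 2)
mse_z = np.mean((znorm(actual) - znorm(pred)) ** 2)

print(f"pearson  raw={r_raw:.4f}  z={r_z:.4f}")      # identical
print(f"spearman raw={rho_raw:.4f}  z={rho_z:.4f}")  # identical
print(f"mse      raw={mse_raw:.4f}  z={mse_z:.4f}")  # different
# closed form for z-scored data: 2(n-1)/n * (1 - r)
print(f"2(n-1)/n*(1-r) = {2 * (n - 1) / n * (1 - r_raw):.4f}")
```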
distribution

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/d1b3a145-f259-4386-818a-6fb7562c7b7a)

code

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src import cfg
from src.analysis.metrics import get_metrics
# generate_markdown is a helper from the project's analysis code (not shown here)

def predictive_performance(
        MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
        TRAIN_DATA_P = lambda set: f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
        NORMALIZE = True,
        n_models=5,
        compare_overlap=False,
        verbose=True,
        plot=False,
        ):
    # PDB IDs seen during training (train + val splits)
    df_t = pd.Index.append(pd.read_csv(TRAIN_DATA_P('train'), index_col=0).index,
                           pd.read_csv(TRAIN_DATA_P('val'), index_col=0).index)
    df_t = df_t.str.upper()

    results_with_overlap = []
    results_without_overlap = []

    for i in range(n_models):
        df = pd.read_csv(MODEL(i), index_col=0).dropna()
        df['pdb'] = df['prot_id'].str.split('_').str[0]

        if NORMALIZE:
            # z-normalization using statistics of the full set
            mean_df = df[['actual','pred']].mean(axis=0, numeric_only=True)
            std_df = df[['actual','pred']].std(axis=0, numeric_only=True)
            df[['actual','pred']] = (df[['actual','pred']] - mean_df) / std_df

        if i == 0:
            print(df)

        # with overlap
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df['actual'], df['pred'])
        results_with_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        # without overlap
        df_no_overlap = df[~(df['pdb'].isin(df_t))]
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df_no_overlap['actual'], df_no_overlap['pred'])
        results_without_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        if i == 0 and plot:
            n_plots = int(compare_overlap) + 1
            fig = plt.figure(figsize=(14, 5*n_plots))
            axes = fig.subplots(n_plots, 1)
            ax = axes[0] if compare_overlap else axes
            # top panel shows the full set when comparing, otherwise the no-overlap subset
            df_top = df if compare_overlap else df_no_overlap
            sns.histplot(df_top['actual'], kde=True, ax=ax, alpha=0.5, label='True pkd')
            sns.histplot(df_top['pred'], kde=True, ax=ax, alpha=0.5, label='Predicted pkd', color='orange')
            ax.set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution")
            ax.legend()

            if compare_overlap:
                sns.histplot(df_no_overlap['actual'], kde=True, ax=axes[1], alpha=0.5, label='True pkd')
                sns.histplot(df_no_overlap['pred'], kde=True, ax=axes[1], alpha=0.5, label='Predicted pkd', color='orange')
                axes[1].set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution (no overlap)")
                axes[1].legend()

    if compare_overlap:
        return generate_markdown([results_with_overlap, results_without_overlap],
                                 names=['with overlap', 'without overlap'], cindex=True, verbose=verbose)
    return generate_markdown([results_without_overlap], names=[r'mean $\pm$ se'], cindex=True, verbose=verbose)
```
jyaacoub commented 2 months ago

2. Mutation impact analysis

Same analysis as above, but this time on $\Delta pkd$ (the change in pkd between each mutant and its wildtype).
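Computing $\Delta pkd$ comes down to pairing each mutant row with its wildtype row. A minimal sketch, assuming codes of the form `<raw_idx>_wt` / `<raw_idx>_mt` (the convention used in the overlap code in this thread) and defining $\Delta pkd$ as mutant minus wildtype; the sign convention and the helper name `compute_dpkd` are assumptions, not the repo's `get_dpkd`:

```python
import pandas as pd


def compute_dpkd(df, col='pkd'):
    """Pair each '<idx>_mt' row with its '<idx>_wt' row and return mt - wt.

    Hypothetical helper; mutants without a wildtype partner are dropped.
    """
    raw_idx = df.index.str.split('_').str[0]
    kind = df.index.str.split('_').str[1]
    is_wt, is_mt = kind == 'wt', kind == 'mt'
    wt = pd.Series(df[col].values[is_wt], index=raw_idx[is_wt])
    mt = pd.Series(df[col].values[is_mt], index=raw_idx[is_mt])
    # subtraction aligns on raw_idx; unmatched entries become NaN and are dropped
    return (mt - wt).dropna()
```

On a prediction CSV, `compute_dpkd(df, 'actual').corr(compute_dpkd(df, 'pred'))` then gives the Pearson correlation between experimental and predicted $\Delta pkd$ (both series share the same `raw_idx` index).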

2.1 $\Delta pkd$ predictive performance

raw $\Delta pkd$

Model-0 Distribution

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/e06bc2c4-84e0-4fc2-af09-855a366ecc9d)

| metric | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
| scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
| mse | 1.505 $\pm$ 0.009 | 1.303 $\pm$ 0.006 | * |
| mae | 0.905 $\pm$ 0.003 | 0.847 $\pm$ 0.002 | * |
| rmse | 1.227 $\pm$ 0.004 | 1.141 $\pm$ 0.003 | * |

Z-normalized

Model-0 Distribution

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/7b0d8675-cf9b-4504-a1fa-6e22f2896c91)

| metric | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
| scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
| mse | 1.649 $\pm$ 0.053 | 1.927 $\pm$ 0.158 | * |
| mae | 0.899 $\pm$ 0.029 | 1.014 $\pm$ 0.023 | * |
| rmse | 1.284 $\pm$ 0.021 | 1.387 $\pm$ 0.057 | * |

Code

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

# %%
print('OVERLAP')
mkdo = tbl_dpkd_metrics_overlap(MODEL, TRAIN_DATA_P, NORMALIZE, plot=False)
print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, plot=False)
```

2.2. Stratify by mutation count


2 classes "single mutation" vs "2+ mutations"

histogram ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/8f9b5feb-6d1f-437f-8e0c-0df60a7d5e5e)

| | 1 mutation | 2+ mutations | Sig |
|---|---|---|---|
| pcorr | 0.076 $\pm$ 0.019 | 0.336 $\pm$ 0.047 | * |
| scorr | 0.053 $\pm$ 0.015 | 0.252 $\pm$ 0.043 | * |
| mse | 1.848 $\pm$ 0.038 | 1.328 $\pm$ 0.095 | * |
| mae | 0.961 $\pm$ 0.031 | 0.833 $\pm$ 0.028 | * |
| rmse | 1.359 $\pm$ 0.014 | 1.152 $\pm$ 0.041 | * |

3 classes "single", "2", "3+"

histogram ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/bb347ccb-14cc-4fd4-b9f5-2860a99ed043)

| | 1 mutation | 2 mutations | 3+ mutations |
|---|---|---|---|
| pcorr | 0.076 $\pm$ 0.019 | 0.207 $\pm$ 0.055 | 0.509 $\pm$ 0.070 |
| scorr | 0.053 $\pm$ 0.015 | 0.131 $\pm$ 0.038 | 0.496 $\pm$ 0.078 |
| mse | 1.848 $\pm$ 0.038 | 1.586 $\pm$ 0.110 | 0.982 $\pm$ 0.140 |
| mae | 0.961 $\pm$ 0.031 | 0.964 $\pm$ 0.016 | 0.732 $\pm$ 0.022 |
| rmse | 1.359 $\pm$ 0.014 | 1.259 $\pm$ 0.043 | 0.989 $\pm$ 0.070 |

Code

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, conditions=[1,2], plot=True)
```
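The mutation-count classes above amount to binning an integer count. A minimal sketch with `pandas.cut`, assuming the number of mutations per mutant is already available as an integer (how it is derived from the Platinum flat file is not shown here). `mutation_count_class` is a hypothetical helper; its `conditions` argument only echoes the `conditions=[1,2]` passed to `tbl_dpkd_metrics_n_mut` and may not match that function's exact semantics:

```python
import pandas as pd


def mutation_count_class(n_mut, conditions=(1, 2)):
    """Bin mutation counts: one bin per value in `conditions`, plus a final 'k+' bin.

    conditions=(1, 2) -> '1', '2', '3+';  conditions=(1,) -> '1', '2+'.
    """
    labels = [str(c) for c in conditions] + [f"{conditions[-1] + 1}+"]
    # right-closed bins (0, c1], (c1, c2], ..., (c_last, inf)
    bins = [0] + list(conditions) + [float('inf')]
    return pd.cut(pd.Series(n_mut), bins=bins, labels=labels)
```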

2.3. Stratify by location of mutation

2.3.1. binding pocket vs not in binding pocket

histogram ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/f3c0a51e-3804-4bda-a58b-110856fab023)

| | mutation in pocket | mutation NOT in pocket | Sig |
|---|---|---|---|
| pcorr | 0.180 $\pm$ 0.029 | 0.110 $\pm$ 0.056 | * |
| scorr | 0.108 $\pm$ 0.022 | 0.050 $\pm$ 0.102 | |
| mse | 1.640 $\pm$ 0.057 | 1.781 $\pm$ 0.113 | * |
| mae | 0.904 $\pm$ 0.021 | 0.981 $\pm$ 0.018 | * |
| rmse | 1.280 $\pm$ 0.022 | 1.334 $\pm$ 0.042 | * |

code

```python
# %%
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src import config as cfg
from src.analysis.figures import get_dpkd

# NOTE: get_in_binding must already be defined; here it is expected to add a 'pocket'
# column to df (0 = wildtype reference, 1 = mutation in pocket, 2 = mutation not in pocket)
NORMALIZE = True

dfr = pd.read_csv(f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv", index_col=0)
df = pd.read_csv("/cluster/home/t122995uhn/projects/data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0).dropna()
df = get_in_binding(df, dfr=dfr)

fig = plt.figure(figsize=(14,5))
ax = fig.subplots(1,1)

# must include 0 in both cases since they are the wildtype reference
true_dpkd_out = get_dpkd(df.query('(pocket == 0) | (pocket == 2)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd_out, kde=True, ax=ax, alpha=0.6, color='orange', label='not in pocket', stat='proportion')

true_dpkd_in = get_dpkd(df.query('(pocket == 0) | (pocket == 1)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd_in, kde=True, ax=ax, alpha=0.6, color=None, label='in pocket', stat='proportion')

ax.set_title(f"{'Normalized ' if NORMALIZE else ''}TRUE Δpkd distribution")
ax.set_xlabel('Δpkd')
ax.legend()
```

2.3.2. distance to ligand


| | counts |
|---|---|
| wt | 981 |
| near lig | 577 |
| not near lig | 372 |

| | mutation near lig (<4A) | mutation not near lig (>4A) | p-val |
|---|---|---|---|
| pcorr | 0.164 $\pm$ 0.012 | 0.198 $\pm$ 0.017 | 0.1405 |
| scorr | 0.104 $\pm$ 0.009 | 0.079 $\pm$ 0.015 | 0.1844 |
| mse | 1.672 $\pm$ 0.023 | 1.604 $\pm$ 0.034 | 0.1405 |
| mae | 0.912 $\pm$ 0.010 | 0.939 $\pm$ 0.005 | 0.0408 |
| rmse | 1.293 $\pm$ 0.009 | 1.266 $\pm$ 0.013 | 0.1381 |

Distribution

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/9e82d7f7-f3f2-4215-8dde-2de126c7f2c3)

Code

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut, tbl_dpkd_metrics_in_binding, predictive_performance, tbl_stratified_dpkd_metrics
from src.analysis.metrics import get_metrics
from src import config as cfg
import pandas as pd

#%%
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
RAW_PLT_CSV = f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"
NORMALIZE = True
n_models = 5
verbose = True
plot = True

dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)

#%%
import seaborn as sns
sns.histplot(dfr['mut.distance_to_lig'])

#%% add in_binding info to df
thres = 4

def get_in_binding(df, dfr):
    """
    df is the predicted csv with index as <raw_idx>_wt (or <raw_idx>_mt), where raw_idx
    corresponds to an index in dfr, which contains the raw Platinum data including
    'mut.distance_to_lig'. Adds a 'near_lig' column:
        - 0: wildtype rows
        - 1: close (< thres Ang)
        - 2: far (>= thres Ang)
    """
    near_lig = dfr[dfr['mut.distance_to_lig'] < thres].index
    pclass = []
    for code in df.index:
        if '_wt' in code:
            pclass.append(0)
        elif int(code.split('_')[0]) in near_lig:
            pclass.append(1)
        else:
            pclass.append(2)
    df['near_lig'] = pclass
    return df

conditions = ['(near_lig == 0) | (near_lig == 1)', '(near_lig == 0) | (near_lig == 2)']
names = [f'mutation near lig (<{thres}A)', f'mutation not near lig (>{thres}A)']

df = get_in_binding(dfp, dfr)
if verbose:
    # sort by class (0, 1, 2) so the labels below line up with the right counts
    cnts = df.near_lig.value_counts().sort_index()
    cnts.index = ['wt', 'near lig', 'not near lig']
    cnts.name = "counts"
    print(cnts.to_markdown(), end="\n\n")

#%%
tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=n_models, df_transform=get_in_binding,
                            conditions=conditions, names=names, verbose=verbose, plot=plot, dfr=dfr)
```
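Because $\Delta pkd$ is defined per mutant, the wildtype rows (`near_lig == 0`) only act as the reference, which is why both conditions include them. Each column of the stratified table is then the same metric computed within one stratum. A minimal sketch for Pearson r; `stratified_pcorr` is a hypothetical helper, not the repo's `tbl_stratified_dpkd_metrics`:

```python
import numpy as np
import pandas as pd
from scipy.stats import pearsonr


def stratified_pcorr(true_dpkd, pred_dpkd, groups):
    """Pearson r between true and predicted Δpkd, computed separately per group label."""
    df = pd.DataFrame({'true': true_dpkd, 'pred': pred_dpkd, 'group': groups})
    out = {}
    for g, sub in df.groupby('group'):
        # pearsonr needs at least two points; smaller groups get NaN
        out[g] = pearsonr(sub['true'], sub['pred'])[0] if len(sub) > 1 else np.nan
    return out


# toy example: two strata with opposite trends
r = stratified_pcorr([1, 2, 3, 1, 2, 3], [1, 2, 3, 3, 2, 1],
                     ['near', 'near', 'near', 'far', 'far', 'far'])
print(r)
```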

jyaacoub commented 2 months ago

3. Significant Mutation impact analysis

With overlap, the best threshold is 0.1*STD:

Figures (ROC curve and sample confusion matrix)

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/48acc8ff-e718-4d56-afa4-4b9af8b2d138) ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/7b8e0fa3-fce9-41ca-84ce-5f47d7cbe9d3)

Without overlap, the best threshold is 0.3*STD:

Figures (ROC curve and sample confusion matrix)

![image](https://github.com/jyaacoub/MutDTA/assets/50300488/35c9b63a-b6b8-4cda-882a-89bb25a85288) ![image](https://github.com/jyaacoub/MutDTA/assets/50300488/8e20e21e-4dec-452d-9304-4cc4cecce72b)

code

```python
#%%
from src.analysis.figures import get_dpkd, fig_sig_mutations_conf_matrix, generate_roc_curve
from src.analysis.metrics import get_metrics
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
data_p = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'

df_t = pd.Index.append(pd.read_csv(data_p('train'), index_col=0).index,
                       pd.read_csv(data_p('val'), index_col=0).index)
df_t = df_t.str.upper()

results_with_overlap = []
results_without_overlap = []

i = 0
df = pd.read_csv(MODEL(i), index_col=0).dropna()
df['pdb'] = df['prot_id'].str.split('_').str[0]
df_no = df[~(df['pdb'].isin(df_t))]

#%%
true_dpkd = get_dpkd(df, pkd_col='actual')
pred_dpkd = get_dpkd(df, pkd_col='pred')
true_dpkd_no = get_dpkd(df_no, pkd_col='actual')
pred_dpkd_no = get_dpkd(df_no, pkd_col='pred')

# %% ROC (with overlap)
_, _, _, best_threshold = generate_roc_curve(true_dpkd, pred_dpkd, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd, pred_dpkd, std=round(best_threshold, 3))

# %% ROC (without overlap)
_, _, _, best_threshold = generate_roc_curve(true_dpkd_no, pred_dpkd_no, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd_no, pred_dpkd_no, std=round(best_threshold, 3))
```
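The threshold search can be sketched as follows. This is one plausible reading of `generate_roc_curve` / `fig_sig_mutations_conf_matrix`, not their actual implementation: a mutation is labeled "significant" when |Δpkd| exceeds k standard deviations (computed separately for true and predicted Δpkd), and k is chosen to maximize Youden's J = TPR - FPR. The use of absolute values, the per-series std and the J criterion are all assumptions, and `sig_confusion` / `best_threshold` are hypothetical helpers. Note that sweeping k moves the cutoff for both the true and predicted labels at once, so this differs from a textbook ROC curve, where the ground-truth labels stay fixed.

```python
import numpy as np


def sig_confusion(true_dpkd, pred_dpkd, k):
    """Return (tp, fp, fn, tn) when |Δpkd| > k * std counts as 'significant'."""
    true_dpkd = np.asarray(true_dpkd, dtype=float)
    pred_dpkd = np.asarray(pred_dpkd, dtype=float)
    t = np.abs(true_dpkd) > k * true_dpkd.std()  # ground-truth labels
    p = np.abs(pred_dpkd) > k * pred_dpkd.std()  # predicted labels
    tp = int(np.sum(t & p))
    fp = int(np.sum(~t & p))
    fn = int(np.sum(t & ~p))
    tn = int(np.sum(~t & ~p))
    return tp, fp, fn, tn


def best_threshold(true_dpkd, pred_dpkd, start=0.0, stop=5.0, step=0.1):
    """Sweep k over [start, stop) and return (k, J) maximizing Youden's J = TPR - FPR."""
    best_k, best_j = None, -np.inf
    for k in np.arange(start, stop, step):
        tp, fp, fn, tn = sig_confusion(true_dpkd, pred_dpkd, k)
        tpr = tp / (tp + fn) if (tp + fn) else 0.0
        fpr = fp / (fp + tn) if (fp + tn) else 0.0
        j = tpr - fpr
        if j > best_j:  # strict '>' keeps the smallest k reaching the best J
            best_k, best_j = float(k), j
    return best_k, best_j
```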
jyaacoub commented 1 month ago

Pretrained Davis results

outline

results:

Code

```python
# %%
import torch, os
import pandas as pd

from src import cfg
from src import TUNED_MODEL_CONFIGS

from src.utils.loader import Loader
from src.train_test.training import test
from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

INFERENCE = False
VERBOSE = True
out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)
cp_dir = cfg.CHECKPOINT_SAVE_DIR
RAW_PLT_CSV = f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"

#%% load up model:
for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
                                                  CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
                                                  n_epochs=2000, fold=fold,
                                                  ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
    print('\n\n' + '## ' + KEY)
    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'
    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"

    if CONFIG['dataset'] in ['kiba', 'davis']:
        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
    else:
        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"

    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"

    if not os.path.exists(OUT_PLT(0)) or INFERENCE:
        print('running inference!')
        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
                                  pro_edge=CONFIG["edge_opt"], **CONFIG['architecture_kwargs'])

        # load up platinum test db
        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                          pro_feature=CONFIG['feature_opt'],
                                          edge_opt=CONFIG['edge_opt'],
                                          ligand_feature=CONFIG['lig_feat_opt'],
                                          ligand_edge=CONFIG['lig_edge_opt'],
                                          datasets=['test'])

        for i in range(5):
            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
            model.to(device)
            model.eval()

            loss, pred, actual = test(model, loaders['test'], device, verbose=True)

            # saving as csv with columns code, pred, actual
            # get codes from test loader
            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
            df.index.name = 'code'
            df.to_csv(OUT_PLT(i))

    # run platinum eval:
    print('\n### 1. predictive performance')
    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n### 2 Mutation impact analysis')
    print('\n#### 2.1 $\\Delta pkd$ predictive performance')
    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)
```

davis_gvpl_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.472 $\pm$ 0.026 |
| pcorr | -0.062 $\pm$ 0.050 |
| scorr | -0.082 $\pm$ 0.078 |
| mse | 2.123 $\pm$ 0.099 |
| mae | 1.122 $\pm$ 0.049 |
| rmse | 1.455 $\pm$ 0.034 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.005 $\pm$ 0.012 | 0.005 $\pm$ 0.012 | 1 |
| scorr | 0.003 $\pm$ 0.025 | 0.003 $\pm$ 0.025 | 1 |
| mse | 1.990 $\pm$ 0.025 | 1.990 $\pm$ 0.025 | 1 |
| mae | 0.980 $\pm$ 0.013 | 0.980 $\pm$ 0.013 | 1 |
| rmse | 1.411 $\pm$ 0.009 | 1.411 $\pm$ 0.009 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.004 $\pm$ 0.015 | 0.025 $\pm$ 0.032 | 0.5777 |
| scorr | -0.009 $\pm$ 0.026 | 0.067 $\pm$ 0.071 | 0.3433 |
| mse | 1.992 $\pm$ 0.030 | 1.950 $\pm$ 0.065 | 0.5777 |
| mae | 1.000 $\pm$ 0.009 | 0.981 $\pm$ 0.021 | 0.4374 |
| rmse | 1.411 $\pm$ 0.011 | 1.396 $\pm$ 0.023 | 0.5578 |

davis_gvpl

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.469 $\pm$ 0.007 |
| pcorr | -0.058 $\pm$ 0.020 |
| scorr | -0.095 $\pm$ 0.021 |
| mse | 2.114 $\pm$ 0.040 |
| mae | 1.148 $\pm$ 0.017 |
| rmse | 1.454 $\pm$ 0.014 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.008 $\pm$ 0.014 | 0.008 $\pm$ 0.014 | 1 |
| scorr | 0.023 $\pm$ 0.011 | 0.023 $\pm$ 0.011 | 1 |
| mse | 1.983 $\pm$ 0.029 | 1.983 $\pm$ 0.029 | 1 |
| mae | 0.974 $\pm$ 0.011 | 0.974 $\pm$ 0.011 | 1 |
| rmse | 1.408 $\pm$ 0.010 | 1.408 $\pm$ 0.010 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 725 |
| not pckt | 256 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.012 $\pm$ 0.012 | -0.008 $\pm$ 0.052 | 0.7244 |
| scorr | 0.014 $\pm$ 0.019 | 0.025 $\pm$ 0.024 | 0.7289 |
| mse | 1.977 $\pm$ 0.024 | 2.016 $\pm$ 0.105 | 0.7244 |
| mae | 0.985 $\pm$ 0.014 | 0.979 $\pm$ 0.023 | 0.8043 |
| rmse | 1.406 $\pm$ 0.009 | 1.418 $\pm$ 0.037 | 0.7599 |

davis_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.446 $\pm$ 0.018 |
| pcorr | -0.127 $\pm$ 0.059 |
| scorr | -0.172 $\pm$ 0.051 |
| mse | 2.253 $\pm$ 0.117 |
| mae | 1.236 $\pm$ 0.032 |
| rmse | 1.499 $\pm$ 0.039 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | -0.007 $\pm$ 0.010 | -0.007 $\pm$ 0.010 | 1 |
| scorr | -0.019 $\pm$ 0.015 | -0.019 $\pm$ 0.015 | 1 |
| mse | 2.015 $\pm$ 0.019 | 2.015 $\pm$ 0.019 | 1 |
| mae | 0.922 $\pm$ 0.024 | 0.922 $\pm$ 0.024 | 1 |
| rmse | 1.419 $\pm$ 0.007 | 1.419 $\pm$ 0.007 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | -0.013 $\pm$ 0.009 | 0.031 $\pm$ 0.024 | 0.117 |
| scorr | -0.039 $\pm$ 0.018 | 0.063 $\pm$ 0.030 | 0.0208 |
| mse | 2.027 $\pm$ 0.018 | 1.937 $\pm$ 0.048 | 0.117 |
| mae | 0.922 $\pm$ 0.023 | 0.951 $\pm$ 0.027 | 0.4358 |
| rmse | 1.424 $\pm$ 0.006 | 1.391 $\pm$ 0.017 | 0.1167 |

PDBbind_gvpl_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.641 $\pm$ 0.011 |
| pcorr | 0.415 $\pm$ 0.028 |
| scorr | 0.411 $\pm$ 0.039 |
| mse | 0.947 $\pm$ 0.034 |
| mae | 0.754 $\pm$ 0.009 |
| rmse | 0.972 $\pm$ 0.018 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.012 | 0.037 $\pm$ 0.035 | 0.0058 |
| scorr | 0.099 $\pm$ 0.009 | 0.046 $\pm$ 0.027 | 0.0974 |
| mse | 1.649 $\pm$ 0.024 | 1.927 $\pm$ 0.071 | 0.0058 |
| mae | 0.899 $\pm$ 0.013 | 1.014 $\pm$ 0.010 | 0.0001 |
| rmse | 1.284 $\pm$ 0.009 | 1.387 $\pm$ 0.025 | 0.005 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.180 $\pm$ 0.013 | 0.110 $\pm$ 0.025 | 0.0374 |
| scorr | 0.108 $\pm$ 0.010 | 0.050 $\pm$ 0.046 | 0.2533 |
| mse | 1.640 $\pm$ 0.026 | 1.781 $\pm$ 0.050 | 0.0374 |
| mae | 0.904 $\pm$ 0.010 | 0.981 $\pm$ 0.008 | 0.0003 |
| rmse | 1.280 $\pm$ 0.010 | 1.334 $\pm$ 0.019 | 0.0365 |