jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

DeepDTA+ overlap in proteins in train and test #11

Closed jyaacoub closed 1 year ago

jyaacoub commented 1 year ago

This issue affects the DeepDTA paper and any follow-ups (including DGraphDTA) that depend on the Kiba and Davis datasets as provided here: https://github.com/hkmztrk/DeepDTA/blob/master/data/README.md.

The proteins in the training set and the test set overlap completely (100% overlap): every protein in the test set also appears in the training set. This is a problem because it becomes impossible to detect overfitting to specific proteins, and the results are no longer a true indication of real-world model performance, since the model has already been exposed to every test protein during training.
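A toy sketch of why this happens, and of one possible alternative: when the split is drawn over (drug, protein) pairs, the pairs of any one protein are scattered across both sets, whereas holding out whole proteins keeps the two sets disjoint. The matrix size, the held-out fraction, and all variable names below are hypothetical choices for illustration, not the DeepDTA folds or the real dataset dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes, NOT the real Davis/Kiba dimensions.
n_drugs, n_proteins = 20, 10
# Every (drug, protein) pair is one sample; record the protein of each pair.
protein_of_pair = np.tile(np.arange(n_proteins), n_drugs)
n_pairs = protein_of_pair.size

# 1) Pair-level split: shuffle the pairs and hold out an arbitrary 1/6.
perm = rng.permutation(n_pairs)
n_test = n_pairs // 6
pair_test, pair_train = perm[:n_test], perm[n_test:]
pair_overlap = set(protein_of_pair[pair_train]) & set(protein_of_pair[pair_test])

# 2) Protein-level split: hold out whole proteins, then assign their pairs.
test_proteins = rng.choice(n_proteins, size=2, replace=False)
is_test = np.isin(protein_of_pair, test_proteins)
cold_train, cold_test = np.flatnonzero(~is_test), np.flatnonzero(is_test)
cold_overlap = set(protein_of_pair[cold_train]) & set(protein_of_pair[cold_test])

print(f'pair-level split:    {len(pair_overlap)} proteins shared by train and test')
print(f'protein-level split: {len(cold_overlap)} proteins shared by train and test')
```

The protein-level split always reports zero shared proteins by construction, at the cost of a test set whose size depends on how many pairs the held-out proteins happen to have.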

To see the overlap, run the following:

# %%
import json
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np

dataset = 'davis'  # or 'kiba'
data_dir = f'/home/jyaacoub/projects/data/davis_kiba/{dataset}'

# Y is the drug x protein affinity matrix; NaN entries are skipped
with open(f'{data_dir}/Y', 'rb') as f:
    Y = pickle.load(f, encoding='latin1')
# the fold files index into this list of non-NaN (row, col) pairs
row_i, col_i = np.where(~np.isnan(Y))
with open(f'{data_dir}/folds/test_fold_setting1.txt') as f:
    test_fold = json.load(f)
with open(f'{data_dir}/folds/train_fold_setting1.txt') as f:
    train_fold = json.load(f)
# the train file holds a list of folds; flatten them into one index list
train_flat = [i for fold in train_fold for i in fold]
# the column index of each pair is its protein index
test_protein_indices = col_i[test_fold]
train_protein_indices = col_i[train_flat]

# %% Overlap in train and test...
overlap = set(train_protein_indices).intersection(set(test_protein_indices))
print(f'number of unique proteins in train: {len(set(train_protein_indices))}')
print(f'number of unique proteins in test:  {len(set(test_protein_indices))}')
print(f'total number of unique proteins:    {len(np.unique(col_i))}')
print(f'Intersection of train and test protein indices: {len(overlap)}')

# %% counts of overlapping proteins
test_counts = Counter(test_protein_indices)
train_counts = Counter(train_protein_indices)

overlap_test_counts = {k: test_counts[k] for k in overlap}
overlap_train_counts = {k: train_counts[k] for k in overlap}

# normalized for set size
norm_overlap_test_counts = {k: v/len(test_protein_indices) for k,v in overlap_test_counts.items()}
norm_overlap_train_counts = {k: v/len(train_protein_indices) for k,v in overlap_train_counts.items()}

# %% plot overlap counts
plt.figure(figsize=(15,10))
plt.subplot(2,1,1)
plt.bar(overlap_train_counts.keys(), overlap_train_counts.values(), label='train', width=1.0)
plt.bar(overlap_test_counts.keys(), overlap_test_counts.values(), label='test', width=1.0)
plt.xlabel('protein index')
plt.ylabel('count')
plt.title(f'Overlap of proteins in train and test ({dataset})')
plt.legend()

plt.subplot(2,1,2)
plt.bar(norm_overlap_train_counts.keys(), norm_overlap_train_counts.values(), label='train', width=1.0)
plt.bar(norm_overlap_test_counts.keys(), norm_overlap_test_counts.values(), label='test', width=1.0)
plt.xlabel('protein index')
plt.ylabel('normalized counts')
plt.title(f'Normalized overlap of proteins in train and test ({dataset})')
plt.legend()
plt.tight_layout()
plt.show()

Kiba train test overlap:

image

Davis train test overlap:

image