jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.

Grouped Cross-Validation #60

Closed jyaacoub closed 1 year ago

jyaacoub commented 1 year ago

Stratified cross-validation is needed for PDBbind and KIBA since, unlike in Davis, the proteins don't show up in equal amounts.

Relevant TDS articles:

  1. perform stratified K-Fold cross-validation on a grouped dataset (an optimization problem)
  2. stratified cross-validation and its implementation in Python using Scikit-Learn (non-grouped data)

1. Grouped stratified K-Fold Cross-validation

From [1]:

the Scikit-Learn Python package provides support for this feature through the StratifiedGroupKFold function. Still, according to the documentation, this function performs a greedy assignment. We can take it a step further using an optimization algorithm.

Problem Model

Rows are groups and columns are classes (as per [1]'s definition)

In our case, a group would be a single unique protein, and the classes can be three categories for low, medium, and high pKd values.
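As a rough illustration of that problem model, the sketch below bins pKd values into three classes and counts, for each protein (group), how many samples fall into each class. The cutoffs (6.0 and 8.0), the helper names, and the toy records are all assumptions made up for illustration; they are not taken from the datasets or from [1].

```python
from collections import defaultdict

# Hypothetical cutoffs: pKd < 6 is "low", 6 <= pKd < 8 is "medium", else "high".
LOW_CUTOFF, HIGH_CUTOFF = 6.0, 8.0
CLASSES = ("low", "medium", "high")

def pkd_class(pkd):
    """Map a pKd value to one of the three illustrative classes."""
    if pkd < LOW_CUTOFF:
        return "low"
    if pkd < HIGH_CUTOFF:
        return "medium"
    return "high"

def group_class_matrix(records):
    """Return {prot_id: [n_low, n_medium, n_high]} from (prot_id, pkd) pairs."""
    matrix = defaultdict(lambda: [0] * len(CLASSES))
    for prot_id, pkd in records:
        matrix[prot_id][CLASSES.index(pkd_class(pkd))] += 1
    return dict(matrix)

# Toy data (made up): protein "P1" has three samples, "P2" has one.
records = [("P1", 4.2), ("P1", 6.5), ("P1", 9.1), ("P2", 7.0)]
matrix = group_class_matrix(records)
print(matrix)
```

Each row of the resulting matrix is one protein and each column one class, which is the shape the grouped stratified formulation works with.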

2. Stratified cross-validation with Scikit-Learn

This method is not applicable in our case since it does not account for the grouped nature of our dataset. Each protein MUST belong to exactly one fold.

jyaacoub commented 1 year ago

Maybe it is better to just do what we have been doing for time considerations?

The main issue here is getting unequal size folds since some proteins appear once or twice while others appear thousands of times...

Maybe it would be alright if we make these considerations by keeping track of the counts for each fold and equally dividing them up?

jyaacoub commented 1 year ago

The new idea is to treat it like the partition problem, but relax the requirement that the sets be exactly equal (since the exact problem is NP-complete).
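A minimal sketch of that relaxed approach, assuming the simplest greedy heuristic: sort the proteins by count, largest first, and always give the next protein to the fold with the smallest running total. The function name `greedy_folds` and the toy counts are hypothetical; this is not the code from the repository.

```python
import heapq

def greedy_folds(prot_counts, k):
    """Assign each protein to the currently lightest of k folds.

    prot_counts: dict mapping protein id -> number of rows (its weight).
    Returns a list of (total_weight, [protein ids]), one entry per fold.
    """
    # Heap entries are (total_weight, fold_index, members); fold_index breaks ties.
    heap = [(0, i, []) for i in range(k)]
    heapq.heapify(heap)
    # Placing heavy proteins first lets the light ones even out the totals.
    for prot, count in sorted(prot_counts.items(), key=lambda kv: kv[1], reverse=True):
        weight, idx, members = heapq.heappop(heap)
        members.append(prot)
        heapq.heappush(heap, (weight + count, idx, members))
    return [(w, m) for w, _, m in sorted(heap, key=lambda e: e[1])]

# Toy counts (made up): one very common protein and several rare ones.
toy_counts = {"A": 8, "B": 5, "C": 4, "D": 3, "E": 2, "F": 1, "G": 1}
folds = greedy_folds(toy_counts, k=2)
for i, (weight, members) in enumerate(folds):
    print(f"Fold {i}: weight={weight}, proteins={members}")
```

The result is not guaranteed to be optimal in general, but every protein lands in exactly one fold and the totals stay close as long as no single protein dominates the dataset.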

Note: weight == counts == # of rows in the dataset. See function on Desmos.

This translates to the following pseudocode for the scoring function that determines where we place the currently selected protein:

score = fold.weight - abs(fold.weight/len(fold) - item.weight)
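
To make the formula concrete, here is a small sketch of the score as a standalone function. The name `fold_score` and the example numbers are illustrative assumptions. Returning -1 for an empty fold is also an assumption, chosen so that empty folds are filled first when the lowest score wins.

```python
def fold_score(fold_weight, fold_len, item_weight):
    """score = fold.weight - abs(fold.weight/len(fold) - item.weight)"""
    if fold_len == 0:
        # Assumption: -1 is a sentinel so an empty fold wins when the lowest
        # score is chosen (scores of non-empty folds are non-negative when
        # items are placed heaviest first).
        return -1
    mean_weight = fold_weight / fold_len
    return fold_weight - abs(mean_weight - item_weight)

# Two folds with the same total weight (10) but different mean weights.
item_weight = 3
score_a = fold_score(10, 2, item_weight)  # mean 5.0 -> 10 - |5.0 - 3| = 8.0
score_b = fold_score(10, 5, item_weight)  # mean 2.0 -> 10 - |2.0 - 3| = 9.0
print(score_a, score_b)
```

With the lowest score winning, the item would go to the first fold here: at equal total weight, the extra term favours the fold whose mean weight is further from the item's weight.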
jyaacoub commented 1 year ago

So I implemented this (see commit: https://github.com/jyaacoub/MutDTA/commit/1d220cb0b3d6d876b0d9c5f1a5b06af51766e0dc), but it works too well... which is suspicious.


jyaacoub commented 1 year ago

It would be nice if I could animate what is going on; it's hard to visualize.

jyaacoub commented 1 year ago

See comment: https://github.com/jyaacoub/MutDTA/commit/1d220cb0b3d6d876b0d9c5f1a5b06af51766e0dc#diff-7aa6f5acee0f196fb1fd47aed9a5262890ef778865b291eb6101fde81f472e5bR332

It works better without the abs(fold[1]/f_len - c) term

Without abs(fold[1]/f_len - c):

        Dataset: PDB
         # | num_prots  | total_count  | final_score
    Fold 0 |    755     |     3253     |    3253   
    Fold 1 |    756     |     3253     |    3253   
    Fold 2 |    758     |     3253     |    3253   
    Fold 3 |    758     |     3253     |    3253   
    Fold 4 |    758     |     3253     |    3252   

With abs(fold[1]/f_len - c):

        Dataset: PDB
         # | num_prots  | total_count  | final_score
    Fold 0 |    754     |     3253     | 3249.685676392573
    Fold 1 |    755     |     3253     | 3249.691390728477
    Fold 2 |    758     |     3253     | 3249.708443271768
    Fold 3 |    759     |     3253     | 3249.714097496706
    Fold 4 |    759     |     3253     | 3248.7097625329816
jyaacoub commented 1 year ago

Code to generate distribution plots:

From https://github.com/jyaacoub/MutDTA/commit/873ef2ffdc0b67d85f3bb0c2d5529320ce7fe06a#diff-012a5034ce7c82b8e203b12a6e5fed17fc07ca43df38595c0cf78ef26ebe8806.

# %%
from collections import Counter, OrderedDict
import pandas as pd

# choose one of: 'davis0', 'kiba', 'pdb'
data = 'pdb'
df = pd.read_csv(f'../data/misc/{data}_XY.csv', index_col=0)

prot_counts = Counter(df['prot_id'])

########## split remaining proteins into k_folds ##########
# Steps for this basically follow Greedy Number Partitioning
# 1. Initialize variables:
#       - folds = list of lists of protein ids
#       - prots = list of proteins  
# 2. Sort protein counts by number of samples
k = 5

# each fold is a list: [list of proteins, total weight, current score]
folds = [[[], 0, -1] for _ in range(k)]
score_history_f1 = [folds[0][2]]
#   - score = fold.weight - abs(fold.weight/len(fold) - item.weight)
# folds are kept in a plain list since their scores
# must be updated every time a protein is added to a fold
counts_sorted = sorted(list(prot_counts.items()), key=lambda x: x[1], reverse=True)
for p, c in counts_sorted:
    # Update scores for each fold
    for fold in folds:
        f_len = len(fold[0])
        if f_len == 0:
            # empty fold keeps its initial score of -1, so it gets filled first
            continue

        # calculate score for adding protein to fold
        fold[2] = fold[1] - abs(fold[1]/f_len - c) # without this term it performs better

    # Finding optimal fold to add protein to (minimize score)
    best_fold = min(folds, key=lambda x: x[2])

    # Add protein to fold
    best_fold[0].append(p)
    # update weight
    best_fold[1] += c

    # update score history
    score_history_f1.append(folds[0][2])
# %% convert folds to set after done selecting for faster lookup
folds_sets = [set(f[0]) for f in folds]

# validating that they dont intersect with each other
for i in range(len(folds_sets)):
    for j in range(i+1, len(folds_sets)):
        assert len(folds_sets[i].intersection(folds_sets[j])) == 0, "Folds intersect"

print("\t\tDataset:", data.upper())
print(f'{"#":>10} | {"num_prots":^10} | {"total_count":^12} | {"final_score":^10}')
for i,f in enumerate(folds): print(f'{"Fold "+str(i):>10} | {len(f[0]):^10} | {f[1]:^12} | {f[2]:^10}')
# %% Plotting distributions of each fold over each other
import matplotlib.pyplot as plt
import seaborn as sns

# x-axis will be the protein counts
# y-axis will be the number of proteins with that count

# get protein counts for each fold
counts = [Counter(df[df['prot_id'].isin(f[0])]['prot_id']) for f in folds]

# %% plot
bin_range = range(0, 100, 2)
kde = True
if data == 'kiba':
    bin_range = range(0, 1400, 100)
elif data == 'davis0':
    bin_range = [67,69] # 68 is the only count for davis
    kde = False

ax = sns.histplot(counts, stat='count', bins=bin_range, kde=kde, alpha=0.3)
# limit x-axis for PDB
if data == 'pdb':
    ax.set_xlim([0, 20])
ax.set_title(f'Protein Count Distribution for {data.upper()}')
ax.set_xlabel('Count of protein')

# %%