jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding-affinity perturbation predictions for a priori identified cancer driver genes.

stratify performance by protein type #57

Closed jyaacoub closed 8 months ago

jyaacoub commented 10 months ago

Kinases in Davis and KIBA might simply be easier to learn than the proteins in PDBbind: http://207.254.60.56/web/current/kinbase/genes/SpeciesID/9606/

Get the full list of human kinase proteins (~538): http://207.254.60.56/web/current/kinbase/gene-sequences/

Match these against the gene names of the Davis and KIBA proteins: https://staff.cs.utu.fi/~aatapa/data/DrugTarget/target_gene_names.txt
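A minimal sketch of the matching step, assuming the gene-name list has already been parsed into a `{dataset_name: gene_name}` dict and the kinbase names into a list (`match_targets` and the toy entries are hypothetical, not the real files). Matching on the gene name, case-insensitively, lets mutant entries fall back to their wild-type kinase, and anything left over is reported in the same `MISSING: name-gene` form used below:

```python
def match_targets(target_genes, kinbase_names):
    """Match dataset targets to kinbase entries by gene name.

    target_genes:  {dataset_name: gene_name}
    kinbase_names: iterable of kinbase protein names
    Returns (matched, missing): matched maps dataset_name -> kinbase name,
    missing lists dataset names whose gene name has no kinbase entry."""
    # case-insensitive lookup table: lowercased name -> original spelling
    lookup = {n.lower(): n for n in kinbase_names}
    matched, missing = {}, []
    for name, gene in target_genes.items():
        hit = lookup.get(gene.lower())
        if hit is None:
            missing.append(name)
        else:
            matched[name] = hit
    return matched, missing

# Toy inputs (hypothetical subset, not the real files)
targets = {"ABL1": "ABL1", "EGFR(L858R)": "EGFR", "AURKA": "AURKA"}
kinbase = ["ABL1", "AurA", "EGFR"]
matched, missing = match_targets(targets, kinbase)
for name in missing:
    print(f"MISSING: {name}-{targets[name]}")  # MISSING: AURKA-AURKA
```

Names that still fail, such as AURKA here (kinbase calls it AurA), are the ones that need a manual alias.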

jyaacoub commented 10 months ago

From https://github.com/jyaacoub/MutDTA/commit/2d71560d0a3c61f796dea83f93a181869441a53c we have the following distribution of protein types:

[figure: distribution of protein types]

Missing proteins:

However, this is not a complete list, since a number of proteins (169) failed to match any entry in the kinbase FASTA file.

The missing protein names are:

MISSING: ACVR1-ACVR1
MISSING: ACVR1B-ACVR1B
MISSING: ACVR2A-ACVR2A
MISSING: ACVR2B-ACVR2B
MISSING: ACVRL1-ACVRL1
MISSING: ANKK1-ANKK1
MISSING: ARK5-ARK5
MISSING: ASK1-ASK1
MISSING: ASK2-ASK2
MISSING: AURKA-AURKA
MISSING: AURKB-AURKB
MISSING: AURKC-AURKC
MISSING: CAMK1-CAMK1
MISSING: CAMK1D-CAMK1D
MISSING: CAMK1G-CAMK1G
MISSING: CAMK2A-CAMK2A
MISSING: CAMK2B-CAMK2B
MISSING: CAMK2D-CAMK2D
MISSING: CAMK2G-CAMK2G
MISSING: CAMK4-CAMK4
MISSING: CAMKK1-CAMKK1
MISSING: CAMKK2-CAMKK2
MISSING: CDC2L1-CDC2L1
MISSING: CDC2L2-CDC2L2
MISSING: CDC2L5-CDC2L5
MISSING: CHEK1-CHEK1
MISSING: CHEK2-CHEK2
MISSING: CIT-CIT
MISSING: CSF1R-CSF1R
MISSING: CSNK1A1-CSNK1A1
MISSING: CSNK1A1L-CSNK1A1L
MISSING: CSNK1D-CSNK1D
MISSING: CSNK1E-CSNK1E
MISSING: CSNK1G1-CSNK1G1
MISSING: CSNK1G2-CSNK1G2
MISSING: CSNK1G3-CSNK1G3
MISSING: CSNK2A1-CSNK2A1
MISSING: CSNK2A2-CSNK2A2
MISSING: DCAMKL1-DCAMKL1
MISSING: DCAMKL2-DCAMKL2
MISSING: DCAMKL3-DCAMKL3
MISSING: DMPK-DMPK
MISSING: EIF2AK1-EIF2AK1
MISSING: EPHA1-EPHA1
MISSING: EPHA2-EPHA2
MISSING: EPHA3-EPHA3
MISSING: EPHA4-EPHA4
MISSING: EPHA5-EPHA5
MISSING: EPHA6-EPHA6
MISSING: EPHA7-EPHA7
MISSING: EPHA8-EPHA8
MISSING: EPHB1-EPHB1
MISSING: EPHB2-EPHB2
MISSING: EPHB3-EPHB3
MISSING: EPHB4-EPHB4
MISSING: EPHB6-EPHB6
MISSING: ERBB2-ERBB2
MISSING: ERBB3-ERBB3
MISSING: ERBB4-ERBB4
MISSING: ERK1-ERK1
MISSING: ERK2-ERK2
MISSING: ERK3-ERK3
MISSING: ERK4-ERK4
MISSING: ERK5-ERK5
MISSING: ERK8-ERK8
MISSING: ERN1-ERN1
MISSING: GRK1-GRK1
MISSING: GRK4-GRK4
MISSING: GRK7-GRK7
MISSING: INSRR-INSRR
MISSING: MAP3K15-MAP3K15
MISSING: MAP4K2-MAP4K2
MISSING: MAP4K3-MAP4K3
MISSING: MAP4K4-MAP4K4
MISSING: MAP4K5-MAP4K5
MISSING: MEK1-MEK1
MISSING: MEK2-MEK2
MISSING: MEK3-MEK3
MISSING: MEK4-MEK4
MISSING: MEK5-MEK5
MISSING: MEK6-MEK6
MISSING: MERTK-MERTK
MISSING: MKK7-MKK7
MISSING: MKNK1-MKNK1
MISSING: MKNK2-MKNK2
MISSING: MLCK-MLCK
MISSING: MRCKA-MRCKA
MISSING: MRCKB-MRCKB
MISSING: MST1R-MST1R
MISSING: MTOR-MTOR
MISSING: MYLK-MYLK
MISSING: MYLK2-MYLK2
MISSING: MYLK4-MYLK4
MISSING: PAK7-PAK7
MISSING: PCTK1-PCTK1
MISSING: PCTK2-PCTK2
MISSING: PCTK3-PCTK3
MISSING: PDGFRA-PDGFRA
MISSING: PDGFRB-PDGFRB
MISSING: PDPK1-PDPK1
MISSING: PFCDPK1(Pfalciparum)-PFCDPK1
MISSING: PFPK5(Pfalciparum)-PFPK5
MISSING: PFTK1-PFTK1
MISSING: PHKG1-PHKG1
MISSING: PHKG2-PHKG2
MISSING: PIK3C2B-PIK3C2B
MISSING: PIK3C2G-PIK3C2G
MISSING: PIK3CA-PIK3CA
MISSING: PIK3CA(C420R)-PIK3CA
MISSING: PIK3CA(E542K)-PIK3CA
MISSING: PIK3CA(E545A)-PIK3CA
MISSING: PIK3CA(E545K)-PIK3CA
MISSING: PIK3CA(H1047L)-PIK3CA
MISSING: PIK3CA(H1047Y)-PIK3CA
MISSING: PIK3CA(I800L)-PIK3CA
MISSING: PIK3CA(M1043I)-PIK3CA
MISSING: PIK3CA(Q546K)-PIK3CA
MISSING: PIK3CB-PIK3CB
MISSING: PIK3CD-PIK3CD
MISSING: PIK3CG-PIK3CG
MISSING: PIK4CB-PIK4CB
MISSING: PIP5K1A-PIP5K1A
MISSING: PIP5K1C-PIP5K1C
MISSING: PIP5K2B-PIP5K2B
MISSING: PIP5K2C-PIP5K2C
MISSING: PKMYT1-PKMYT1
MISSING: PKNB(Mtuberculosis)-PKNB
MISSING: PRKCD-PRKCD
MISSING: PRKCE-PRKCE
MISSING: PRKCH-PRKCH
MISSING: PRKCI-PRKCI
MISSING: PRKCQ-PRKCQ
MISSING: PRKD1-PRKD1
MISSING: PRKD2-PRKD2
MISSING: PRKD3-PRKD3
MISSING: PRKG1-PRKG1
MISSING: PRKG2-PRKG2
MISSING: PRKR-PRKR
MISSING: RIPK4-RIPK4
MISSING: RIPK5-RIPK5
MISSING: ROS1-ROS1
MISSING: RPS6KA4(KinDom.1-N-terminal)-RPS6KA4
MISSING: RPS6KA4(KinDom.2-C-terminal)-RPS6KA4
MISSING: RPS6KA5(KinDom.1-N-terminal)-RPS6KA5
MISSING: RPS6KA5(KinDom.2-C-terminal)-RPS6KA5
MISSING: S6K1-S6K1
MISSING: SBK1-SBK1
MISSING: SIK2-SIK2
MISSING: SNARK-SNARK
MISSING: SRMS-SRMS
MISSING: SRPK3-SRPK3
MISSING: STK16-STK16
MISSING: STK35-STK35
MISSING: STK36-STK36
MISSING: STK39-STK39
MISSING: TAOK1-TAOK1
MISSING: TAOK2-TAOK2
MISSING: TAOK3-TAOK3
MISSING: TGFBR1-TGFBR1
MISSING: TGFBR2-TGFBR2
MISSING: TNK2-TNK2
MISSING: TNNI3K-TNNI3K
MISSING: TRPM6-TRPM6
MISSING: TSSK1B-TSSK1B
MISSING: VEGFR2-VEGFR2
MISSING: WEE1-WEE1
MISSING: WEE2-WEE2
MISSING: YSK4-YSK4
jyaacoub commented 10 months ago

For the missing proteins, we need to go through each one and identify it manually, since some of the aliases are missing from kinbase.

jyaacoub commented 10 months ago

After a painstaking 2 hours of looking up each protein, I got the following list, which maps each one to the alias that matches its name in the kinbase FASTA file:

missing_prots = {
    # TGF superfamily
    'ACVR1':  'ALK2',       # https://www.ncbi.nlm.nih.gov/gene/90
    'ACVR1B': 'ALK4',       # https://www.ncbi.nlm.nih.gov/gene/91
    'ACVR2A': 'ACTR2',      # https://www.ncbi.nlm.nih.gov/gene/92
    'ACVR2B': 'ACTR2B',     # https://www.ncbi.nlm.nih.gov/gene/93
    'ACVRL1': 'ALK1',       # https://www.ncbi.nlm.nih.gov/gene/94

    'ANKK1': 'sgk288',      # https://www.ncbi.nlm.nih.gov/gene/255239

    'ARK5': 'NuaK1',        # https://www.ncbi.nlm.nih.gov/gene/9891

    'ASK1': 'MAP3K5',       # https://en.wikipedia.org/wiki/ASK1
    'ASK2': 'MAP3K6',       # https://www.ncbi.nlm.nih.gov/gene/9064

    'AURKA': 'AurA',        # https://www.ncbi.nlm.nih.gov/gene/6790
    'AURKB': 'AurB',        # https://www.ncbi.nlm.nih.gov/gene/9212
    'AURKC': 'AurC',        # https://www.ncbi.nlm.nih.gov/gene/6795

    'CDC2L1':   'CDK11',    # https://en.wikipedia.org/wiki/CDC2L1
    'CDC2L2':   'CDK11',    # https://www.ncbi.nlm.nih.gov/gene/493708
    'CDC2L5':   'CHED',     # https://www.ncbi.nlm.nih.gov/gene/8621

    'CHEK1':    'CHK1',     # https://www.ncbi.nlm.nih.gov/gene/1111
    'CHEK2':    'CHK2',     # https://www.ncbi.nlm.nih.gov/gene/11200

    'CIT':      'CRIK',     # https://www.ncbi.nlm.nih.gov/gene/11113
    'CSF1R':    'FMS',      # https://www.ncbi.nlm.nih.gov/gene/1436

    'CSNK1A1':  'CK1a',     # https://www.ncbi.nlm.nih.gov/gene/1452
    'CSNK1A1L': 'CK1',      # https://www.ncbi.nlm.nih.gov/gene/122011
    'CSNK1D':   'CK1d',     # https://www.ncbi.nlm.nih.gov/gene/1453
    'CSNK1E':   'CK1e',     # https://www.ncbi.nlm.nih.gov/gene/1454
    'CSNK1G1':  'CK1g1',    # https://www.ncbi.nlm.nih.gov/gene/53944
    'CSNK1G2':  'CK1g2',    # https://www.ncbi.nlm.nih.gov/gene/1455
    'CSNK1G3':  'CK1g3',    # https://www.ncbi.nlm.nih.gov/gene/1456
    'CSNK2A1':  'CK2a1',    # https://www.ncbi.nlm.nih.gov/gene/1457
    'CSNK2A2':  'CK2a2',    # https://www.ncbi.nlm.nih.gov/gene/1459

    'DCAMKL1':  'DCLK1',    # https://www.ncbi.nlm.nih.gov/gene/9201
    'DCAMKL2':  'DCLK2',    # https://www.ncbi.nlm.nih.gov/gene/166614
    'DCAMKL3':  'DCLK3',    # https://www.ncbi.nlm.nih.gov/gene/85443

    'EIF2AK1':  'HRI',      # https://www.ncbi.nlm.nih.gov/gene/27102
    'ERK8':     'ERK7',     # https://www.ncbi.nlm.nih.gov/gene/225689
    'ERN1':     'IRE1',     # https://www.ncbi.nlm.nih.gov/gene/2081

    'GRK1':     'RHOK',     # https://www.ncbi.nlm.nih.gov/gene/6011
    'GRK4':     'GPRK4',    # https://www.ncbi.nlm.nih.gov/gene/2868
    'GRK7':     'GPRK7',    # https://www.ncbi.nlm.nih.gov/gene/131890

    'INSRR':    'IRR',      # https://www.ncbi.nlm.nih.gov/gene/3645
    'MAP3K15':  'MAP3K2',   # https://www.ncbi.nlm.nih.gov/gene/389840 **Unsure on this one
    'MAP4K2':   'GCK',      # https://www.ncbi.nlm.nih.gov/gene/5871
    'MAP4K3':   'HGK',      # https://www.ncbi.nlm.nih.gov/gene/8491 **
    'MAP4K4':   'HGK',      # https://www.ncbi.nlm.nih.gov/gene/9448
    'MAP4K5':   'KHS1',     # https://www.ncbi.nlm.nih.gov/gene/11183

    'MEK1':     'Erk1',     # https://www.ncbi.nlm.nih.gov/gene/828713
    'MEK2':     'Erk2',     # *
    'MEK3':     'Erk3',     # *
    'MEK4':     'Erk4',     # *
    'MEK5':     'Erk5',     # *
    'MEK6':     'Erk7',     # *

    'MERTK':    'MER',      # https://www.ncbi.nlm.nih.gov/gene/10461
    'MKK7':     'MAP2K7',   # https://www.ncbi.nlm.nih.gov/gene/5609

    'MKNK1':    'MNK1',     # https://www.ncbi.nlm.nih.gov/gene/8569
    'MKNK2':    'MNK2',     # https://www.ncbi.nlm.nih.gov/gene/2872

    'MST1R':    'RON',      # https://www.ncbi.nlm.nih.gov/gene/4486
    'MTOR':     'FRAP',     # https://www.ncbi.nlm.nih.gov/gene/2475

    'MYLK':     'smMLCK',   # https://www.ncbi.nlm.nih.gov/gene/4638
    'MYLK2':    'skMLCK',   # https://www.ncbi.nlm.nih.gov/gene/85366
    'MYLK4':    'SgK085',   # https://www.ncbi.nlm.nih.gov/gene/340156

    'PAK7':     'PAK5',     # https://www.ncbi.nlm.nih.gov/gene/57144

    'PCTK1':    'PCTAIRE1', # https://www.ncbi.nlm.nih.gov/gene/5127
    'PCTK2':    'PCTAIRE2', # https://www.ncbi.nlm.nih.gov/gene/5128
    'PCTK3':    'PCTAIRE3', # https://www.ncbi.nlm.nih.gov/gene/5129

    'PDPK1':    'PDK1',     # https://www.ncbi.nlm.nih.gov/gene/5170
    'PFCDPK1':  'Other',  # https://www.ncbi.nlm.nih.gov/gene/815931 ** discontinued?
    'PFPK5':    'CDK2',     # High sequence similarity to CDK2 (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6436633/)
    'PFTK1':    'PFTAIRE1', # https://www.ncbi.nlm.nih.gov/gene/5218

    'PIK3C2B':  'PIK3R4',   # https://www.ncbi.nlm.nih.gov/gene/5287 **
    'PIK3C2G':  'PIK3R4',   # https://www.ncbi.nlm.nih.gov/gene/5288 **
    'PIK3CA':   'PIK3R4',   # *
    'PIK3CB':   'PIK3R4',   # *
    'PIK3CD':   'PIK3R4',   # *
    'PIK3CG':   'PIK3R4',   # *
    'PIK4CB':   'PIK3R4',   # *

    'PIP5K1A':  'Other',    # https://www.ncbi.nlm.nih.gov/gene/8394 ** FAILED TO FIND ALIAS
    'PIP5K1C':  'Other',
    'PIP5K2B':  'Other',
    'PIP5K2C':  'Other',

    'PKMYT1':   'MYT1',     # https://www.ncbi.nlm.nih.gov/gene/9088
    'PKNB':     'Other',    # https://www.ncbi.nlm.nih.gov/gene/887072 ** FAILED TO FIND ALIAS

    'PRKCD':    'PKCd',     # https://www.ncbi.nlm.nih.gov/gene/5580
    'PRKCE':    'PKCe',     # https://www.ncbi.nlm.nih.gov/gene/5581
    'PRKCH':    'PKCh',     # https://www.ncbi.nlm.nih.gov/gene/5583
    'PRKCI':    'PKCi',     # https://www.ncbi.nlm.nih.gov/gene/5584
    'PRKCQ':    'PKCt',     # https://www.ncbi.nlm.nih.gov/gene/5588 *

    'PRKD1':    'PKD1',     # https://www.ncbi.nlm.nih.gov/gene/5587
    'PRKD2':    'PKD2',     # https://www.ncbi.nlm.nih.gov/gene/25865
    'PRKD3':    'PKD3',     # https://www.ncbi.nlm.nih.gov/gene/23683

    'PRKG1':    'PKG1',     # https://www.ncbi.nlm.nih.gov/gene/5592
    'PRKG2':    'PKG2',     # https://www.ncbi.nlm.nih.gov/gene/5593

    'PRKR':     'PKR',      # https://www.ncbi.nlm.nih.gov/gene/5610

    'RIPK4':    'ANKRD3',   # https://www.ncbi.nlm.nih.gov/gene/54101
    'RIPK5':    'ANKRD3',   # https://www.ncbi.nlm.nih.gov/gene/11035 ** FAILED TO FIND ALIAS

    'ROS1':     'ROS',      # https://www.ncbi.nlm.nih.gov/gene/6098

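    # NOTE: 'RPS6KA4' appears twice; a dict literal keeps only the last value,
    # so 'RSK3' is silently overwritten by 'RSK1~b'.
    'RPS6KA4':  'RSK3',     # https://www.ncbi.nlm.nih.gov/gene/8986
    'RPS6KA4':  'RSK1~b',   # https://www.ncbi.nlm.nih.gov/gene/8986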
    'RPS6KA5':  'MSK1',     # https://www.ncbi.nlm.nih.gov/gene/9252

    'S6K1':     'p70S6K',   # https://www.ncbi.nlm.nih.gov/gene/6198
    'SBK1':     'SBK',      # https://www.ncbi.nlm.nih.gov/gene/388228
    'SIK2':     'QIK',      # https://www.ncbi.nlm.nih.gov/gene/23235
    'SNARK':    'NUAK2',    # https://www.ncbi.nlm.nih.gov/gene/81788
    'SRMS':     'SRM',      # https://www.ncbi.nlm.nih.gov/gene/6725
    'SRPK3':    'MSSK1',    # https://www.ncbi.nlm.nih.gov/gene/26576

    'STK16':    'MPSK1',    # https://www.ncbi.nlm.nih.gov/gene/8576
    'STK35':    'CLIK1',    # https://www.ncbi.nlm.nih.gov/gene/140901
    'STK36':    'CLIK1',    # https://www.ncbi.nlm.nih.gov/gene/27148 ** FAILED TO FIND ALIAS
    'STK39':    'PASK',     # https://www.ncbi.nlm.nih.gov/gene/27347

    'TAOK1':    'TAO1',     # https://www.ncbi.nlm.nih.gov/gene/57551
    'TAOK2':    'TAO2',     # https://www.ncbi.nlm.nih.gov/gene/9344
    'TAOK3':    'TAO3',     # https://www.ncbi.nlm.nih.gov/gene/51347

    'TNK2':     'ACK',      # https://www.ncbi.nlm.nih.gov/gene/10188
    'TNNI3K':   'p38a',     # https://www.ncbi.nlm.nih.gov/gene/51208 ** FAILED TO FIND ALIAS

    'TRPM6':    'ChaK2',    # https://www.ncbi.nlm.nih.gov/gene/140803
    'TSSK1B':   'TSSK1',    # https://www.ncbi.nlm.nih.gov/gene/83942
    'VEGFR2':   'KDR',      # https://www.ncbi.nlm.nih.gov/gene/3791
    'WEE2':     'Wee1B',    # https://www.ncbi.nlm.nih.gov/gene/494551
    'YSK4':     'MAP3K1',   # https://www.ncbi.nlm.nih.gov/gene/80122 ** FAILED TO FIND ALIAS
}
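A dict literal silently keeps only the last value for a repeated key (as happens with `'RPS6KA4'` above). A minimal sketch of a safer pattern, assuming the aliases are written as `(name, alias)` pairs and checked before use; `build_alias_map` and `resolve` are hypothetical helpers, and only a few entries from the mapping are reproduced. `resolve` also strips a trailing parenthetical tag (mutation, organism, or kinase-domain label) before applying the alias:

```python
import re

def build_alias_map(pairs):
    """Build {dataset_name: kinbase_alias} from (name, alias) pairs.

    Unlike a dict literal, this raises if one name is given two
    different aliases instead of keeping the last one."""
    alias_map = {}
    for name, alias in pairs:
        if name in alias_map and alias_map[name] != alias:
            raise ValueError(f"conflicting aliases for {name}: "
                             f"{alias_map[name]!r} vs {alias!r}")
        alias_map[name] = alias
    return alias_map

def resolve(name, alias_map, kinbase_names):
    """Return the kinbase spelling for a dataset target name, or None.

    Strips a trailing '(...)' tag, applies the alias map, then matches
    case-insensitively against kinbase_names."""
    base = re.sub(r"\(.*\)$", "", name)
    candidate = alias_map.get(base, base)
    lookup = {n.lower(): n for n in kinbase_names}
    return lookup.get(candidate.lower())

# A few pairs copied from the mapping above
aliases = build_alias_map([("AURKA", "AurA"), ("CHEK1", "CHK1"), ("PFPK5", "CDK2")])
kinbase = ["AurA", "CHK1", "CDK2"]  # toy subset of kinbase names
print(resolve("AURKA", aliases, kinbase))               # AurA
print(resolve("PFPK5(Pfalciparum)", aliases, kinbase))  # CDK2
print(resolve("UNKNOWN1", aliases, kinbase))            # None
```

Feeding both `'RPS6KA4'` pairs to `build_alias_map` raises a `ValueError`, which forces a decision on which alias (or which kinase domain) is intended.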
jyaacoub commented 10 months ago

[figure: distribution of protein kinase families in the Davis dataset]

Code (including the missing proteins from the mapping above):

#%%
import seaborn as sns
import pandas as pd
import json
import matplotlib.pyplot as plt
from src.data_analysis.stratify_protein import check_davis_names, kinbase_to_df
from src.utils import config as cfg

df = kinbase_to_df()

#%%
missing_prot_merged = {}
# missing_prots is the alias dict from the previous comment
for k, name in missing_prots.items():
    matches = df.index[df.index.str.lower() == name.lower()]

    if len(matches) > 0:
        name = matches[0]
        missing_prot_merged[k] = (df.loc[name, 'main_family'], df.loc[name, 'subgroup'], None)
    else:
        missing_prot_merged[k] = (name, None, None)

# %%
missing_df = pd.DataFrame.from_dict(missing_prot_merged, orient='index', 
                                    columns=df.columns)

#%%
df = pd.concat([df, missing_df])

#%%

with open(f'{cfg.DATA_ROOT}/davis/proteins.txt', 'r') as fh:
    prot_dict = json.load(fh)
prots = check_davis_names(prot_dict, df)

# %% plot histogram of main families and their counts
main_families = [v[1] for v in prots.values()]
main_families = pd.Series(main_families).value_counts().sort_values(ascending=False)
sns.set_theme(style='darkgrid')
plt.figure(figsize=(10, 5))
sns.barplot(x=main_families.index, y=main_families.values, color='b')
plt.xlabel('Protein Kinase Family')
plt.title('Distribution of Protein Kinase Families in Davis Dataset')
plt.ylabel('Count (442 total proteins)')
#plt.savefig('results/figures/davis_kinaseFamilies.png', dpi=300, bbox_inches='tight')
jyaacoub commented 10 months ago

The following plot shows the differences between the EDI and DG models, stratified by protein family:

[figure: MSE loss by protein family and subgroup for EDIM and DGM]

Code:

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from src.data_analysis.stratify_protein import map_davis_to_kinbase

# Get kinbase data to map to family names
kin_df = pd.read_csv('../data/misc/kinase_base_updated.csv', index_col='name')

#%%
subgroups_to_plot = ['TK', 'STE', 'Other', 'CAMK', 'AGC']
models_to_plot = ['EDI', 'DG']

fig, axes = plt.subplots(len(subgroups_to_plot)+1, len(models_to_plot), 
                       figsize=(5*len(models_to_plot), 4*(len(subgroups_to_plot)+1)))
for i, model_type in enumerate(models_to_plot):
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/test_set_pred/EDIM_davis{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_testPred.csv'
    elif model_type == 'DG':
        model_path = lambda x: f'results/model_media/test_set_pred/DGM_davis{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_testPred.csv'

    # Do the same but this time with error bars by using cross validation
    # data will be a dict of {main_family: [mse1, mse2, ...], ...}
    data_main = {}
    data_subgroups = {} # {main_family: {subgroup: [mse1, mse2, ...], ...}, ...}

    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        # returns a dict of {davis_name: (kinbase_name, main_family, subgroup)}
        pred_kb = map_davis_to_kinbase(pred.index.unique(), df=kin_df) # should be the same for all folds (same test set)

        # update pred to have kinbase info
        pred['kinbase_name'] = pred.index.map(lambda x: pred_kb[x][0])
        pred['main_family'] = pred.index.map(lambda x: pred_kb[x][1])
        pred['subgroup'] = pred.index.map(lambda x: pred_kb[x][2])

        for f in pred.main_family.unique():
            matched = pred[pred.main_family == f]
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main family mse to dict
            data_main[f] = data_main.get(f, []) + [mse]

            # add main_family subgroup mse to dict
            data_subgroups[f] = data_subgroups.get(f, {})

            for g in matched.subgroup.unique():
                g_matched = matched[matched.subgroup == g]
                mse = ((g_matched.pred - g_matched.actual)**2).mean()
                data_subgroups[f][g] = data_subgroups[f].get(g, []) + [mse]

    # plot mse as bar chart
    plot_df = pd.DataFrame(data_main)
    curr_ax = axes[0, i]
    sns.barplot(data=plot_df, ax=curr_ax)
    curr_ax.set_ylabel('MSE')
    curr_ax.set_xlabel('Main Family')
    curr_ax.set_title(f'MSE loss for {model_type}M by Protein Family')
    curr_ax.set_ylim(0, 1.4)

    for j, f in enumerate(subgroups_to_plot):
        curr_ax = axes[j+1, i]

        sns.barplot(data=pd.DataFrame(data_subgroups[f]), ax=curr_ax)
        if i == 0:
            curr_ax.set_ylabel('MSE')
        curr_ax.set_xlabel(f'{f} Subgroups')
        curr_ax.set_ylim(0, 1.4)

plt.tight_layout()
jyaacoub commented 10 months ago

Some interesting results when plotting MSE by protein subgroup. Here we see the opposite of what we might expect: the higher the protein count, the worse the model does on the test set!

This also shows evidence of the ESM model overfitting more.

[figures: subgroup size vs test MSE for the EDI and DG models]

Code:

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from src.data_analysis.stratify_protein import map_davis_to_kinbase

# Get kinbase data to map to family names
kin_df = pd.read_csv('../data/misc/kinase_base_updated.csv', index_col='name')
sns.set_style('darkgrid')

# %%
# get count of every subgroup
# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv',
                      index_col='code')

prot_counts = df_full.index.value_counts()
# {davis_name: (kinbase_name, main_family, subgroup)} for every protein in the full dataset
kb_dict = map_davis_to_kinbase(df_full.index.unique(), df=kin_df)
kb_df = pd.DataFrame.from_dict(kb_dict, orient='index', 
                               columns=['kinbase', 'main_family', 'subgroup'])

subgroup_counts = {} # {subgroup: count, ...}
for idx in df_full.index:
    subgrp = kb_df.loc[idx, 'subgroup']
    subgroup_counts[subgrp] = subgroup_counts.get(subgrp, 0) + 1

# %% get subgroup mse
for model_type in ['EDI', 'DG']:
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/test_set_pred/EDIM_davis{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_testPred.csv'
    elif model_type == 'DG':
        model_path = lambda x: f'results/model_media/test_set_pred/DGM_davis{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_testPred.csv'

    # Do the same but this time with error bars by using cross validation
    # data will be a dict of {main_family: [mse1, mse2, ...], ...}
    data_main = {}
    data_subgroups = {} # {main_family: {subgroup: [mse1, mse2, ...], ...}, ...}

    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        # returns a dict of {davis_name: (kinbase_name, main_family, subgroup)}
        kb_dict = map_davis_to_kinbase(pred.index.unique(), df=kin_df) # should be the same for all folds (same test set)

        # update pred to have kinbase info
        pred['kinbase_name'] = pred.index.map(lambda x: kb_dict[x][0])
        pred['main_family'] = pred.index.map(lambda x: kb_dict[x][1])
        pred['subgroup'] = pred.index.map(lambda x: kb_dict[x][2])

        for f in pred.main_family.unique():
            matched = pred[pred.main_family == f]
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main family mse to dict
            data_main[f] = data_main.get(f, []) + [mse]

            # add main_family subgroup mse to dict
            data_subgroups[f] = data_subgroups.get(f, {})

            for g in matched.subgroup.unique():
                g_matched = matched[matched.subgroup == g]
                mse = ((g_matched.pred - g_matched.actual)**2).mean()
                data_subgroups[f][g] = data_subgroups[f].get(g, []) + [mse]

    subgroup_mse = {} # {subgroup: [mse1, mse2, ...], ...}
    for k in data_subgroups.keys():
        for k2 in data_subgroups[k].keys():
            subgroup_mse[k2] = subgroup_mse.get(k2, []) + data_subgroups[k][k2]

    # merge counts with mse
    x = []
    y = []
    z = []
    for k in subgroup_mse.keys():
        # check for nan
        if np.isnan(subgroup_mse[k]).any():
            continue
        x += [subgroup_counts[k]] * len(subgroup_mse[k])
        y += subgroup_mse[k]
        z += [k] * len(subgroup_mse[k])

    # scatter plot with x axis as count and y axis as mse
    plt.figure(figsize=(10, 5))
    ax = sns.scatterplot(x=x, y=y, hue=z)
    # line of best fit
    m, b = np.polyfit(x, y, 1)
    plt.plot(x, m*np.array(x) + b, color='black', linestyle='dotted', label=f'y={m*10000:.2f}e-4x+{b:.2f}', linewidth=2)

    plt.xlabel('Number of Proteins in Subgroup')
    plt.ylabel('MSE')
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.title(f'Subgroup Size vs Test MSE ({model_type} Model)')
    plt.show()
    plt.clf()

    # %%
jyaacoub commented 10 months ago

The training-set plot looks a little different, with a Pearson product-moment correlation coefficient of 0.076:

[figure: subgroup size vs training MSE]

For the test set, the correlation was 0.412.
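The correlations quoted here come from `np.corrcoef`, i.e. Pearson's r. In the scatter plots every group contributes one point per fold, so the same x value is repeated up to five times. A minimal sketch of an alternative, assuming per-group fold MSEs are collected in a dict as in the snippets above: average each group's fold MSEs first so every group counts once, and also report a rank-based (Spearman) correlation, which is less sensitive to a few outlier groups. The helper name and toy numbers are hypothetical:

```python
import numpy as np
import pandas as pd

def size_vs_error(counts, fold_mse):
    """Correlate group size with mean error across folds.

    counts:   {group: number of rows for that group in the full dataset}
    fold_mse: {group: [mse_fold0, mse_fold1, ...]}
    Returns (pearson_r, spearman_rho, slope, intercept), computed on one
    point per group after averaging its fold MSEs."""
    groups = sorted(fold_mse)
    x = np.array([counts[g] for g in groups], dtype=float)
    y = np.array([np.mean(fold_mse[g]) for g in groups])
    pearson = np.corrcoef(x, y)[0, 1]
    # Spearman's rho is Pearson's r computed on ranks (ties get average ranks)
    spearman = np.corrcoef(pd.Series(x).rank(), pd.Series(y).rank())[0, 1]
    slope, intercept = np.polyfit(x, y, 1)  # least-squares line of best fit
    return pearson, spearman, slope, intercept

# Toy data (illustrative values only)
counts = {"a": 68, "b": 136, "c": 204, "d": 340}
fold_mse = {"a": [0.2, 0.3], "b": [0.4, 0.5], "c": [0.5, 0.7], "d": [1.0, 1.2]}
r, rho, m, b = size_vs_error(counts, fold_mse)
print(f"pearson={r:.3f} spearman={rho:.3f} slope={m:.4f}")
```

Comparing the two coefficients is a quick check of whether a correlation is driven by a handful of extreme groups (Pearson high, Spearman low) or holds across the ranking.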

jyaacoub commented 10 months ago

Using MMseqs2 to generate my own clusters, we get essentially the same outcome:

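DG model:

[figure: cluster size vs test MSE (DG model)]
Correlation: 0.32254

EDI model:

[figure: cluster size vs test MSE (EDI model)]
Correlation: 0.374

Training looks a bit better:

[figure: cluster size vs training MSE]
Correlation: 0.1168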

Code:

# %%
from src.utils.mmseq2 import MMseq2Runner
import pandas as pd 

tsvp = '../data/misc/davis_clustering//tsvs/davisDB_4sens_198clust.tsv'
# read tsv
df = pd.read_csv(tsvp, sep='\t', header=None)
# rename cols
df.columns = ['rep', 'member']

clusters = df.groupby('rep')['member'].apply(list).to_dict()
len(clusters)

# %% drop clusters with fewer than 5 members (merging them into a cluster 0 is commented out below)
clusters_new = {}
for k in clusters.keys():
    if len(clusters[k]) < 5:
        continue
        #clusters_new[0] = clusters_new.get(0, []) + clusters[k]
    else:
        clusters_new[k] = clusters[k]

for k in clusters_new.keys():
    print(len(clusters_new[k]))

clusters = clusters_new

#%% Rename clusters keys to be ints 
clusters = {i: set(clusters[k]) for i, k in enumerate(clusters)}

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('darkgrid')

# %%
# count dataset rows (protein-ligand pairs) per cluster
# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv',
                      index_col='code')

prot_counts = df_full.index.value_counts()

cluster_counts = {} # {cluster: count, ...}
for idx in df_full.index:
    for k in clusters.keys():
        if idx in clusters[k]:
            cluster_counts[k] = cluster_counts.get(k, 0) + 1
            break

# %% get cluster mse
subset = 'test' # or 'train'
for model_type in ['EDI']:
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_davis{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
    elif model_type == 'DG':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_davis{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_{subset}Pred.csv'

    data_clust = {} # {cluster: [mse1, mse2, ...], ...}
    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        for k in clusters.keys():
            # get mse for cluster
            matched = pred[pred.index.isin(clusters[k])]
            if matched.empty:
                continue
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main mse to dict
            data_clust[k] = data_clust.get(k, []) + [mse]

    # merge counts with mse
    x = []
    y = []
    z = []
    for k in data_clust.keys():
        x += [cluster_counts[k]] * len(data_clust[k])
        y += data_clust[k]
        z += [k] * len(data_clust[k])

    # scatter plot with x axis as count and y axis as mse
    plt.figure(figsize=(10, 5))
    ax = sns.scatterplot(x=x, y=y, hue=z)
    # line of best fit
    m, b = np.polyfit(x, y, 1)
    plt.plot(x, m*np.array(x) + b, color='black', linestyle='dotted', label=f'y={m*10000:.2f}e-4x+{b:.2f}', linewidth=2)

    # correlation
    corr = np.corrcoef(x, y)[0, 1]
    print(f'Correlation: {corr}')

    plt.xlabel('Number of Proteins in cluster')
    plt.ylabel('MSE')
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.title(f'Cluster Size vs {subset} MSE ({model_type} Model)')
    plt.show()
    plt.clf()
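The nested loop that builds `cluster_counts` scans every cluster for every dataset row. A minimal sketch of an equivalent approach, assuming the MMseqs2 cluster TSV has been read into the two columns `rep` and `member` as above: invert the clusters into a member-to-cluster lookup once, then count with `map` and `value_counts`. `min_size=5` would mirror the filter used above; the helper names and toy data are hypothetical:

```python
import pandas as pd

def member_to_cluster(tsv_df: pd.DataFrame, min_size: int = 1) -> dict:
    """Map each protein name to an integer cluster id.

    tsv_df has columns 'rep' and 'member' (one row per cluster member);
    clusters with fewer than min_size members are dropped and the rest
    are renumbered 0, 1, 2, ..."""
    clusters = tsv_df.groupby("rep")["member"].apply(list)
    clusters = clusters[clusters.map(len) >= min_size]
    return {m: i for i, members in enumerate(clusters) for m in members}

def rows_per_cluster(row_names, lookup: dict) -> pd.Series:
    """Count dataset rows (protein-ligand pairs) per cluster id.

    Rows whose protein is not in any kept cluster are ignored."""
    ids = pd.Series(row_names).map(lookup).dropna().astype(int)
    return ids.value_counts()

# Toy cluster TSV and dataset index (illustrative only)
tsv = pd.DataFrame({"rep":    ["A", "A", "C", "C", "C"],
                    "member": ["A", "B", "C", "D", "E"]})
lookup = member_to_cluster(tsv, min_size=3)
print(lookup)  # {'C': 0, 'D': 0, 'E': 0}
row_index = pd.Index(["A", "C", "C", "D", "E", "B"])
print(rows_per_cluster(row_index, lookup).to_dict())
```

The lookup dict also makes per-cluster MSE a single `groupby` on `pred.index.map(lookup)` rather than an `isin` filter per cluster.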
jyaacoub commented 9 months ago

Since KIBA doesn't have equal counts for all of its proteins, we can visualize the overfitting this way instead.

KIBA protein count vs. performance:

Test Sets:

Training sets:

Code:

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('darkgrid')
dataset = 'kiba'

# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv(f'../data/DavisKibaDataset/{dataset}/nomsa_binary_original_binary/full/XY.csv',
                      index_col='code')
prot_counts = df_full.index.value_counts()

# %% get per-protein mse
subset = 'train' # or 'test'

for model_type in ['EDI', 'DG']:
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_{dataset}{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
    elif model_type == 'DG':
        if dataset == 'davis':
            model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_{dataset}{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
        elif dataset == 'kiba':
            model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_{dataset}{x}D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_{subset}Pred.csv'

    data_clust = {} # {cluster: [mse1, mse2, ...], ...}
    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        for k in prot_counts.keys():
            matched = pred[pred.index == k]
            if matched.empty: # will be empty if not in this fold
                continue
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main mse to dict
            data_clust[k] = data_clust.get(k, []) + [mse]

    # merge counts with mse
    x = []
    y = []
    z = []
    for k in data_clust.keys():
        x += [prot_counts[k]] * len(data_clust[k])
        y += data_clust[k]
        z += [k] * len(data_clust[k])

    # scatter plot with x axis as count and y axis as mse
    plt.figure(figsize=(10, 5))
    ax = sns.scatterplot(x=x, y=y, hue=z)
    # line of best fit
    m, b = np.polyfit(x, y, 1)
    lines = plt.plot(x, m*np.array(x) + b, color='black', linestyle='dotted', 
                     label=f'y={m*10000:.2f}e-4x+{b:.2f}', linewidth=2)

    # correlation
    corr = np.corrcoef(x, y)[0, 1]
    plt.xlabel('Protein Count in Entire Dataset')
    plt.ylabel('MSE')
    plt.legend(handles=[lines[0]], loc='upper left', title=f'Correlation: {corr:.3f}')
    plt.title(f'Protein Count vs {subset} MSE ({model_type} Model)')
    plt.show()
    plt.clf()
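Filtering with `pred[pred.index == k]` inside a loop over every protein rescans the whole prediction table once per protein, which gets slow on KIBA. A minimal sketch of the same per-protein, per-fold MSE using one `groupby` per fold, assuming prediction tables indexed by protein `name` with `pred` and `actual` columns as above; `per_protein_mse` and the toy data are hypothetical:

```python
import numpy as np
import pandas as pd

def per_protein_mse(folds):
    """Return a long table with one row per (protein, fold) and its MSE.

    Proteins absent from a fold simply get no row for that fold."""
    parts = []
    for fold, pred in enumerate(folds):
        sq_err = (pred["pred"] - pred["actual"]) ** 2
        mse = sq_err.groupby(level=0).mean()  # group by the protein-name index
        parts.append(pd.DataFrame({"protein": mse.index.to_numpy(),
                                   "fold": fold,
                                   "mse": mse.to_numpy()}))
    return pd.concat(parts, ignore_index=True)

# Toy folds (illustrative values only)
fold0 = pd.DataFrame({"pred": [5.0, 6.0, 7.0], "actual": [5.0, 5.0, 6.0]},
                     index=pd.Index(["P1", "P1", "P2"], name="name"))
fold1 = pd.DataFrame({"pred": [5.5], "actual": [5.0]},
                     index=pd.Index(["P2"], name="name"))

long = per_protein_mse([fold0, fold1])
# protein counts in the full dataset (hypothetical numbers)
prot_counts = pd.Series({"P1": 2, "P2": 5})
long["count"] = long["protein"].map(prot_counts)
corr = np.corrcoef(long["count"], long["mse"])[0, 1]
print(long)
```

The resulting long table plugs directly into `sns.scatterplot(data=long, x='count', y='mse', hue='protein')`.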
jyaacoub commented 9 months ago


So I think we are getting what we expected; there are just a few outlier clusters that are causing problems:

WORST PERFORMERS

### Cluster 26 has 16 proteins and mean mse of 0.581 with std 0.143
['PAK2', 'PAK1', 'YES', 'FYN', 'VRK2', 'SRC', 'FRK', 'BLK', 'HCK', 'MST1', 'LYN', 'TXK', 'MST2', 'PAK3', 'FGR', 'LCK']
{'CK1': {'VRK'}, 'STE': {'STE20'}, 'TK': {'Src', 'Tec'}}

pkd stats (higher is stronger):
min     5.000000
mean    5.627895
max     9.744727
std     0.975937
Name: pkd, dtype: float64

prot seq stats:
min     487.000000
mean    531.687500
max     599.000000
std      29.298102
Name: prot_seq, dtype: float64

### Cluster 38 has 1 proteins and mean mse of 0.913 with std 0.067
['RIPK2']
{'TKL': {'RIPK'}}

pkd stats (higher is stronger):
min     5.000000
mean    5.546609
max     8.337242
std     0.835465
Name: pkd, dtype: float64

prot seq stats:
min     540.0
mean    540.0
max     540.0
std       0.0
Name: prot_seq, dtype: float64

### Cluster 43 has 44 proteins and mean mse of 1.306 with std 0.130
['ULK2', 'EPHA5', 'KIT(V559D)', 'MAP4K2', 'KIT(V559D-T670I)', 'HPK1', 'MERTK', 'FAK', 'EPHA8', 'LOK', 'EPHA7', 'ULK1', 'EPHA4', 'LZK', 'KIT(A829P)', 'EPHA2', 'FLT3(N841I)', 'RIPK5', 'FLT3', 'EPHB6', 'FLT3(K663Q)', 'FLT3(R834Q)', 'EPHA1', 'FLT3(ITD)', 'KIT(V559D-V654A)', 'PYK2', 'EPHB4', 'TAOK3', 'TAOK1', 'FLT3(D835Y)', 'EPHA3', 'DLK', 'EPHB1', 'MAP4K5', 'AAK1', 'MAP4K3', 'KIT(D816V)', 'CSF1R', 'KIT(L576P)', 'NEK9', 'EPHB3', 'FLT3(D835H)', 'KIT(D816H)', 'KIT']
{   'Other': {'NEK', 'NAK', 'ULK'},
    'STE': {'STE20'},
    'TK': {'FAK', 'PDGFR', 'Eph', 'Axl'},
    'TKL': {'MLK', 'RIPK'}}

pkd stats (higher is stronger):
min      5.000000
mean     5.809830
max     10.431798
std      1.184850
Name: pkd, dtype: float64

prot seq stats:
min      820.000000
mean     971.181818
max     1074.000000
std       52.262170
Name: prot_seq, dtype: float64

##################################################
BEST PERFORMERS

### Cluster 7 has 12 proteins and mean mse of 0.141 with std 0.024
['CDK5', 'CDK4-cyclinD3', 'NEK7', 'CDC2L5', 'PIM2', 'STK16', 'NEK6', 'PIM3', 'CDK4-cyclinD1', 'CDK3', 'CDK2', 'PFPK5(Pfalciparum)']
{'CAMK': {'PIM'}, 'CMGC': {'CDK'}, 'Other': {'NEK', 'NAK'}}

pkd stats (higher is stronger):
min     5.000000
mean    5.246409
max     9.292430
std     0.665727
Name: pkd, dtype: float64

prot seq stats:
min     288.000000
mean    305.833333
max     326.000000
std      10.875838
Name: prot_seq, dtype: float64

### Cluster 16 has 13 proteins and mean mse of 0.142 with std 0.072
['BMPR1B', 'TGFBR1', 'CTK', 'SRMS', 'ACVR1', 'BMPR1A', 'FGFR3(G697C)', 'BRK', 'ACVR1B', 'FGFR3', 'CSK', 'ACVRL1', 'BTK']
{'TK': {'FGFR', 'Src', 'Csk', 'Tec'}, 'TKL': {'STKR'}}

pkd stats (higher is stronger):
min     5.000000
mean    5.328586
max     9.070581
std     0.703244
Name: pkd, dtype: float64

prot seq stats:
min     450.000000
mean    491.769231
max     532.000000
std      22.294562
Name: prot_seq, dtype: float64

### Cluster 25 has 21 proteins and mean mse of 0.050 with std 0.018
['MKNK2', 'CAMK4', 'MAPKAPK5', 'MKNK1', 'GSK3B', 'PCTK3', 'CAMK2D', 'NDR1', 'CAMKK1', 'CAMK2A', 'SGK3', 'AKT1', 'CAMK1G', 'ULK3', 'AKT3', 'NDR2', 'NIM1', 'STK33', 'AKT2', 'CHEK1', 'PCTK1']
{   'AGC': {'NDR', 'Akt', 'SGK'},
    'CAMK': {'CAMK-Unique', 'CAMK1', 'CAMKL', 'CAMK2', 'MAPKAPK'},
    'CMGC': {'CDK', 'GSK'},
    'Other': {'CAMKK', 'ULK'}}

pkd stats (higher is stronger):
min      5.000000
mean     5.296790
max     10.408935
std      0.714052
Name: pkd, dtype: float64

prot seq stats:
min     420.000000
mean    471.476190
max     514.000000
std      22.367702
Name: prot_seq, dtype: float64

### Cluster 29 has 13 proteins and mean mse of 0.028 with std 0.017
['PIK3CD', 'PIK3CA(E545A)', 'PIK3CA(M1043I)', 'PIK3CA', 'PIK3CA(E542K)', 'PIK3CA(Q546K)', 'PIK3CA(C420R)', 'PIK3CA(H1047Y)', 'PIK3CB', 'PIK3CA(E545K)', 'PIK3CA(H1047L)', 'PIK3CG', 'PIK3CA(I800L)']
{'Other': {'VPS15'}}

pkd stats (higher is stronger):
min     5.000000
mean    5.173230
max     9.065502
std     0.683815
Name: pkd, dtype: float64

prot seq stats:
min     1044.000000
mean    1069.692308
max     1102.000000
std       11.471384
Name: prot_seq, dtype: float64
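Every cluster above reports a minimum pkd of exactly 5.0. Assuming the Davis Kd values (in nM) were converted with the usual DeepDTA-style transform pKd = -log10(Kd in molar), that floor is the 10,000 nM cap Davis uses for weak or non-binding pairs, so a cluster mean close to 5 means most of its pairs sit at the cap. A minimal sketch of the conversion (`kd_nm_to_pkd` is a hypothetical helper, not a MutDTA function):

```python
import math

def kd_nm_to_pkd(kd_nm: float) -> float:
    """pKd = -log10(Kd in molar) = 9 - log10(Kd in nM); higher means tighter binding."""
    return 9.0 - math.log10(kd_nm)

# 10,000 nM (the Davis cap) maps to pKd 5; a 1 nM binder maps to pKd 9
for kd in (10_000, 100, 1):
    print(f"Kd = {kd} nM -> pKd = {kd_nm_to_pkd(kd):.1f}")
```

This is why "higher is stronger" in the stats printed above: a smaller Kd (tighter binding) gives a larger pKd.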

Code

# %%
import pprint
from src.utils.mmseq2 import MMseq2Runner
from src.data_analysis.stratify_protein import map_davis_to_kinbase
import pandas as pd 

# tsvp = '../data/misc/davis_clustering//tsvs/davisDB_4sens_198clust.tsv'
# tsvp = '../data/misc/davis_clustering/tsvs/davisDB_4sens_5cov_9c.tsv' # 49 clusters
tsvp = '../data/misc/davis_clustering/tsvs/davisDB_4sens_9c_5cov.tsv'
# read tsv
df = pd.read_csv(tsvp, sep='\t', header=None)
# rename cols
df.columns = ['rep', 'member']

clusters = df.groupby('rep')['member'].apply(list).to_dict()
print(len(clusters), 'clusters')

# %% drop empty clusters (pooling clusters with < 5 members into cluster 0 is currently disabled)
clusters_new = {}
for k in clusters.keys():
    if len(clusters[k]) == 0: # skip empty clusters
        continue
        #clusters_new[0] = clusters_new.get(0, []) + clusters[k]
    else:
        clusters_new[k] = clusters[k]

# for k in clusters_new.keys():
#     print(len(clusters_new[k]))

clusters = clusters_new

#%% Rename clusters keys to be ints 
clusters = {i: set(clusters[k]) for i, k in enumerate(clusters)}

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('darkgrid')

# %%
# count the drug-target pairs (rows) that fall in each cluster
# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv',
                      index_col='code')

prot_counts = df_full.index.value_counts()

cluster_counts = {} # {cluster: count, ...}
for idx in df_full.index:
    for k in clusters.keys():
        if idx in clusters[k]:
            cluster_counts[k] = cluster_counts.get(k, 0) + 1
            break

# %% Map davis to kinbase:
kin_df = pd.read_csv('../data/misc/kinase_base_updated.csv', index_col='name')
pred_kb = map_davis_to_kinbase(df_full.index.unique(), kin_df)

# %% get subgroup mse

def get_cluster_details(clust):
    # get subgroups
    subgroups = {}
    for prot in clust:
        fam = pred_kb[prot][1]
        sg = pred_kb[prot][2]

        if fam in subgroups:
            subgroups[fam].add(sg)
        else:
            subgroups[fam] = {sg}
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(subgroups)
    # print(clust) # protids
    print('\npkd stats (higher is stronger):')
    print(df_full.loc[clust, 'pkd'].describe()[['min', 'mean', 'max', 'std']]) # pkd stats

    # print sequence length stats
    print('\nprot seq stats:')
    print(df_full.loc[clust, 'prot_seq'].str.len().describe()[['min', 'mean', 'max', 'std']])

subset = 'test' #'test'
for model_type in ['EDI']:
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_davis{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
    elif model_type == 'DG':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_davis{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_{subset}Pred.csv'

    data_clust = {} # {cluster: [mse1, mse2, ...], ...}
    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        for k in clusters.keys():
            # get mse for cluster
            matched = pred[pred.index.isin(clusters[k])]
            if matched.empty:
                continue
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main mse to dict
            data_clust[k] = data_clust.get(k, []) + [mse]

    # merge counts with mse
    x = []
    y = []
    z = []
    for k in data_clust.keys():
        x += [cluster_counts[k]] * len(data_clust[k])
        y += data_clust[k]
        z += [k] * len(data_clust[k])

    # scatter plot with x axis as count and y axis as mse
    plt.figure(figsize=(10, 5))
    ax = sns.scatterplot(x=x, y=y, hue=z)
    # line of best fit
    m, b = np.polyfit(x, y, 1)
    lines = plt.plot(x, m*np.array(x) + b, color='black', linestyle='dotted', 
                     label=f'y={m*10000:.2f}e-4x+{b:.2f}', linewidth=2)

    # correlation
    corr = np.corrcoef(x, y)[0, 1]
    plt.xlabel('Number of drug-target pairs in cluster')
    plt.ylabel('MSE')
    plt.legend(handles=[lines[0]], loc='upper left', title=f'Correlation: {corr:.3f}')
    plt.title(f'Cluster Size vs {subset} MSE ({model_type} Model)')
    plt.show()
    plt.clf()

    # print worse performers
    print('WORST PERFORMERS')
    for k in data_clust.keys():
        if np.mean(data_clust[k]) > 0.5:
            clust = list(clusters[k])
            print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
                    f'with std {np.std(data_clust[k]):.3f}')
            get_cluster_details(clust)

    # print best performers
    print('')
    print('#'*50)
    print('BEST PERFORMERS')
    for k in data_clust.keys():
        if np.mean(data_clust[k]) < 0.15:
            clust = list(clusters[k])
            print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
                    f'with std {np.std(data_clust[k]):.3f}')
            get_cluster_details(clust)
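The inner loop above rescans the whole prediction table once per cluster. A minimal alternative sketch, assuming each protein belongs to exactly one cluster (which holds for an MMseqs2 `rep`/`member` TSV, where every member has a single representative), inverts the cluster dict once and lets pandas `groupby` compute all cluster MSEs in one pass. `per_cluster_mse` is a hypothetical helper, not part of MutDTA:

```python
import pandas as pd

def per_cluster_mse(pred: pd.DataFrame, clusters: dict) -> pd.Series:
    """Return the MSE of `pred` grouped by the cluster each protein belongs to.

    pred:     indexed by protein name, with 'pred' and 'actual' columns
              (the layout of the *_testPred.csv files read above)
    clusters: {cluster_id: set_of_protein_names}
    Proteins that appear in no cluster are dropped.
    """
    # invert {cluster: {proteins}} into {protein: cluster} once
    prot_to_clust = {p: k for k, members in clusters.items() for p in members}
    cluster_ids = pred.index.map(prot_to_clust)  # NaN where a protein is unmapped
    sq_err = (pred['pred'] - pred['actual']) ** 2
    return sq_err.groupby(cluster_ids).mean()  # NaN keys are excluded by default

# toy data for illustration only (not results from this issue)
toy_pred = pd.DataFrame(
    {'pred': [5.0, 6.0, 7.0, 5.5], 'actual': [5.0, 5.0, 7.0, 5.0]},
    index=pd.Index(['ABL1', 'ABL1', 'EGFR', 'BRAF'], name='name'),
)
toy_clusters = {0: {'ABL1', 'EGFR'}, 1: {'BRAF'}}
print(per_cluster_mse(toy_pred, toy_clusters))
```

Inside the fold loop, the `for k in clusters.keys()` block could then become `for k, mse in per_cluster_mse(pred, clusters).items(): data_clust.setdefault(k, []).append(mse)`, giving the same per-cluster mean of squared errors.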
jyaacoub commented 9 months ago

Plotting each cluster's EDI MSE against its DG MSE on the Davis dataset shows which clusters are predicted better by EDI than by DG.

Raw MSE for each cluster (5 points per cluster, one for each of the 5 training folds)

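[Figure: per-cluster EDI MSE (x-axis) vs DG MSE (y-axis), one point per fold, with a dashed y = x reference line]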

Mean MSE for each cluster (mean of the 5 folds)

[Figure: mean EDI MSE (x-axis) vs mean DG MSE (y-axis) per cluster, with a dashed y = x reference line]

Clusters below the y = x line have a lower MSE with DG, and clusters above it have a lower MSE with EDI. As expected, most clusters lie above the line, which agrees with our observation that EDI performs better overall.
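The above/below reading of the plot can also be checked numerically. A minimal sketch, assuming the `{cluster: [per-fold MSE, ...]}` layout of `model_cluster_mse['EDI']` and `model_cluster_mse['DG']` from the script below, that splits clusters by which model has the lower mean MSE (`compare_models` is a hypothetical helper name):

```python
import numpy as np

def compare_models(edi_mse: dict, dg_mse: dict):
    """Split clusters by which model has the lower mean MSE.

    Each dict maps cluster id -> list of per-fold MSEs.
    Returns (edi_better, dg_better, ties) as sorted lists of cluster ids.
    """
    assert edi_mse.keys() == dg_mse.keys(), "models were evaluated on different clusters"
    edi_better, dg_better, ties = [], [], []
    for k in edi_mse:
        edi, dg = np.mean(edi_mse[k]), np.mean(dg_mse[k])
        # plot has x = EDI MSE, y = DG MSE: above y = x means DG has the larger error
        if dg > edi:
            edi_better.append(k)
        elif dg < edi:
            dg_better.append(k)
        else:
            ties.append(k)
    return sorted(edi_better), sorted(dg_better), sorted(ties)

# toy per-fold MSEs for illustration only
edi = {0: [0.20, 0.22], 1: [0.50, 0.40], 2: [0.10, 0.10]}
dg = {0: [0.30, 0.28], 1: [0.35, 0.41], 2: [0.10, 0.10]}
above, below, on = compare_models(edi, dg)
print(above, below, on)  # [0] [1] [2]
```

`len(above) / len(edi)` then gives the fraction of clusters that lie above the line.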

CODE:

# %%
from src.utils.mmseq2 import MMseq2Runner
from src.data_analysis.stratify_protein import map_davis_to_kinbase
import pandas as pd 
import os

## DATASET PARAMETERS (uncomment one)
# DATASET = 'pdbbind'
# DATASET = 'kiba'
DATASET = 'davis'

## CLUSTERING PARAMETERS
CLUST_OPTION = '4sens_9c_5cov'

## PLOTTING PARAMETERS
GET_MEAN = True

## Processing for getting correct paths based on dataset:
if DATASET in ('davis', 'kiba'):
    XY_csv_path = f'../data/DavisKibaDataset/{DATASET}'
elif DATASET == 'pdbbind':
    XY_csv_path = '../data/PDBbindDataset'
else:
    raise ValueError('Invalid dataset')

XY_csv_path += '/nomsa_binary_original_binary/full/XY.csv'
clust_dir = f'../data/misc/{DATASET}_clustering'

if DATASET == 'davis':
    # tsvp = f'{clust_dir}//tsvs/davisDB_4sens_198clust.tsv'
    # tsvp = f'{clust_dir}/tsvs/davisDB_4sens_5cov_9c.tsv' # 49 clusters
    tsvp = f'{clust_dir}/tsvs/{DATASET}DB_4sens_9c_5cov.tsv' # 198 clusters
elif DATASET == 'kiba':
    tsvp = f'{clust_dir}/tsvs/{DATASET}DB_4sens_9c_5cov.tsv'
elif DATASET == 'pdbbind':
    tsvp = f'{clust_dir}/tsvs/{DATASET}DB_4sens_9c_5cov.tsv'
else:
    raise ValueError('Invalid dataset')

assert os.path.exists(tsvp), f"TSV file {os.path.basename(tsvp)} does not exist. Please run clustering first."

# %%

# read tsv
df = pd.read_csv(tsvp, sep='\t', header=None)
# rename cols
df.columns = ['rep', 'member']

clusters = df.groupby('rep')['member'].apply(list).to_dict()
print(len(clusters), 'clusters')

# %% drop empty clusters (pooling clusters with < 5 members into cluster 0 is currently disabled)
clusters_new = {}
for k in clusters.keys():
    if len(clusters[k]) == 0: # skip empty clusters
        continue
        #clusters_new[0] = clusters_new.get(0, []) + clusters[k]
    else:
        clusters_new[k] = clusters[k]

clusters = clusters_new

#%% Rename clusters keys to be ints 
clusters = {i: set(clusters[k]) for i, k in enumerate(clusters)}

#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('darkgrid')

# %%
# count the drug-target pairs (rows) that fall in each cluster
# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv(XY_csv_path,
                      index_col='code')

prot_counts = df_full.index.value_counts()

cluster_counts = {} # {cluster: count, ...}
for idx in df_full.index:
    for k in clusters.keys():
        if idx in clusters[k]:
            cluster_counts[k] = cluster_counts.get(k, 0) + 1
            break

# %% get subgroup mse
subset = 'test' #'test'
# plot mse vs mse for EDI and DG
model_cluster_mse = {} # {model_type: {cluster: mse, ...}, ...}
for model_type in ['EDI', 'DG']:
    d = DATASET if DATASET != 'pdbbind' else 'PDBbind'
    if model_type == 'EDI':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_{d}{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
    elif model_type == 'DG':
        model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_{d}{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_{subset}Pred.csv'

    data_clust = {} # {cluster: [mse1, mse2, ...], ...}
    for fold in range(5):
        pred = pd.read_csv(model_path(fold), index_col='name')

        for k in clusters.keys():
            # get mse for cluster
            matched = pred[pred.index.isin(clusters[k])]
            if matched.empty:
                continue
            mse = ((matched.pred - matched.actual)**2).mean()

            # add main mse to dict
            data_clust[k] = data_clust.get(k, []) + [mse]

    # store this model's per-cluster MSEs
    model_cluster_mse[model_type] = data_clust.copy()

# %% plot mse vs mse for EDI and DG
# verify that every cluster evaluated for EDI was also evaluated for DG
# (a mismatch means the two models were evaluated on different subsets)
for k in list(model_cluster_mse['EDI'].keys()):
    assert k in model_cluster_mse['DG'], "cluster not in both models, this is likely due to mismatched subsets being used."

x,y,z = [], [], []
for k in model_cluster_mse['EDI'].keys():
    assert len(model_cluster_mse['EDI'][k]) == len(model_cluster_mse['DG'][k]), "Cluster size mismatch. Are you using the same number of folds for both models?"
    if GET_MEAN:
        x.append(np.mean(model_cluster_mse['EDI'][k]))
        y.append(np.mean(model_cluster_mse['DG'][k]))
        z.append(k)
    else:
        x += model_cluster_mse['EDI'][k]
        y += model_cluster_mse['DG'][k]
        z += [k]*len(model_cluster_mse['EDI'][k])

sns.set_style('darkgrid')

plt.figure(figsize=(10,5))
sns.scatterplot(x=x, y=y, hue=z, palette='tab20')

# plot line at y=x for reference (spanning the range of both axes)
lim = max(max(x), max(y))
plt.plot([0, lim], [0, lim], color='black', linestyle='--')

plt.xlabel('EDI MSE')
plt.ylabel('DG MSE')
plt.title(f'MSE of clusters for EDI and DG models ({subset} set)')
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.show()
plt.clf()    

# %%
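Both scripts mention pooling clusters with fewer than 5 members into a single cluster 0, but that step is commented out. If it is ever re-enabled, a minimal sketch under the same assumptions (the `{rep: [members]}` dict built from the TSV, with 5 as the size threshold from the code comment) could look like this; `pool_small_clusters` is a hypothetical name:

```python
def pool_small_clusters(clusters: dict, min_size: int = 5) -> dict:
    """Merge every cluster with fewer than `min_size` members into cluster 0.

    clusters: {representative: [members, ...]} as built from the MMseqs2 TSV.
    Returns {int_id: set_of_members}; id 0 is the pooled cluster (if any).
    """
    pooled = set()
    kept = []
    for rep, members in clusters.items():
        if len(members) < min_size:
            pooled.update(members)
        else:
            kept.append(set(members))
    out = {0: pooled} if pooled else {}
    # number the remaining clusters from 1 so id 0 always means "pooled"
    for i, members in enumerate(kept, start=1):
        out[i] = members
    return out

toy = {'A': ['A', 'B', 'C', 'D', 'E'], 'F': ['F', 'G'], 'H': ['H']}
print(pool_small_clusters(toy))
```

Because it returns integer ids directly, this would replace both the filtering cell and the `enumerate` renaming cell, and reserving id 0 keeps the pooled group from colliding with a real cluster.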
jyaacoub commented 8 months ago

This has gone stale and there is not much else to add. Closing...