jyaacoub / MutDTA

Improving the precision oncology pipeline by predicting binding affinity perturbations for a priori identified cancer driver genes.

Try structural aware vocabulary (SaProt) #53

Closed: jyaacoub closed this issue 8 months ago

jyaacoub commented 11 months ago

HuggingFace Model - https://huggingface.co/westlake-repl

SaProt paper - https://www.biorxiv.org/content/10.1101/2023.10.01.560349v1

jyaacoub commented 10 months ago

github - https://github.com/westlake-repl/SaProt
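SaProt's "structure-aware vocabulary" pairs every residue with a Foldseek 3Di structure state. As a rough sketch, assuming the combined-token format shown in the SaProt repository README (uppercase amino acid followed by the lowercase 3Di letter, with `#` standing in for masked structure, e.g. in low-confidence regions), the per-protein input string could be built as below. The helper name `make_sa_sequence` and the pLDDT threshold of 70 are our own assumptions, not part of SaProt's API.

```python
from typing import Optional, Sequence


def make_sa_sequence(
    aa_seq: str,
    foldseek_seq: str,
    plddt: Optional[Sequence[float]] = None,
    plddt_threshold: float = 70.0,
    mask_char: str = "#",
) -> str:
    """Hypothetical helper: interleave residues and 3Di letters into SaProt-style tokens."""
    if len(aa_seq) != len(foldseek_seq):
        raise ValueError("amino-acid and 3Di sequences must have the same length")
    if plddt is not None and len(plddt) != len(aa_seq):
        raise ValueError("need one pLDDT value per residue")

    tokens = []
    for i, (aa, struc) in enumerate(zip(aa_seq, foldseek_seq)):
        # Assumed convention: replace low-confidence structure states with the mask char
        if plddt is not None and plddt[i] < plddt_threshold:
            struc = mask_char
        # Each combined token = uppercase residue + lowercase 3Di state
        tokens.append(aa.upper() + struc.lower())
    return "".join(tokens)


print(make_sa_sequence("MEVQ", "dvpp"))  # MdEvVpQp
```

The resulting string is what would then be passed to SaProt's tokenizer; each two-character token indexes one entry of the combined residue-by-structure vocabulary.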

jyaacoub commented 10 months ago

Preliminary results on Davis, without properly tuned hyperparameters:

C-index plot (image not included)

Output values (C-index, one value per cross-validation fold):

| Feature  | C-index per fold                                  |
|----------|---------------------------------------------------|
| nomsa    | 0.837200, 0.825905, 0.842545, 0.821925, 0.811849  |
| msa      | 0.829333, 0.821276, 0.826504, 0.840412, 0.823426  |
| esm      | 0.855447, 0.843247, 0.855414, 0.842134, 0.842102  |
| foldseek | 0.852516, 0.848126, 0.858061, 0.861351            |
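For reference, the C-index as usually defined for drug-target affinity benchmarks such as Davis is the fraction of pairs with different true affinities whose predictions are ordered the same way, with tied predictions counted as half. Below is a plain-Python O(n²) sketch of that definition; we assume, but have not checked, that MutDTA's `cindex` column is computed the same way.

```python
from typing import Sequence


def concordance_index(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Fraction of comparable pairs whose predictions are ordered like the labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    concordant = 0.0
    comparable = 0
    n = len(y_true)
    for i in range(n):
        for j in range(n):
            # A pair is comparable when the true affinities differ; i > j counts it once
            if y_true[i] > y_true[j]:
                comparable += 1
                diff = y_pred[i] - y_pred[j]
                if diff > 0:
                    concordant += 1.0  # ranked in the correct order
                elif diff == 0:
                    concordant += 0.5  # tied prediction counts as half
    if comparable == 0:
        raise ValueError("no comparable pairs (all true values are equal)")
    return concordant / comparable


# 5 of the 6 comparable pairs are ordered correctly -> 0.833...
print(concordance_index([5.0, 6.2, 7.1, 8.4], [5.5, 7.0, 6.8, 8.1]))
```

A C-index of 0.5 corresponds to random ordering and 1.0 to a perfect ranking, which is why differences in the third decimal place between features are small in absolute terms.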

MSE plot (image not included)

Output values (MSE, one value per cross-validation fold):

| Feature  | MSE per fold                                      |
|----------|---------------------------------------------------|
| nomsa    | 0.403854, 0.477717, 0.415949, 0.434732, 0.429358  |
| msa      | 0.449363, 0.471800, 0.453965, 0.419492, 0.460131  |
| esm      | 0.380870, 0.343248, 0.355290, 0.365248, 0.375826  |
| foldseek | 0.370823, 0.364358, 0.363241, 0.360638            |
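To make the per-fold numbers easier to compare, they can be reduced to a mean and sample standard deviation per feature. This is a small standard-library sketch; the dictionaries only copy the preliminary Davis values reported above (foldseek has four folds, the others five).

```python
from statistics import mean, stdev

# Per-fold values copied from the preliminary Davis results above
CINDEX = {
    "nomsa": [0.837200, 0.825905, 0.842545, 0.821925, 0.811849],
    "msa": [0.829333, 0.821276, 0.826504, 0.840412, 0.823426],
    "esm": [0.855447, 0.843247, 0.855414, 0.842134, 0.842102],
    "foldseek": [0.852516, 0.848126, 0.858061, 0.861351],
}
MSE = {
    "nomsa": [0.403854, 0.477717, 0.415949, 0.434732, 0.429358],
    "msa": [0.449363, 0.471800, 0.453965, 0.419492, 0.460131],
    "esm": [0.380870, 0.343248, 0.355290, 0.365248, 0.375826],
    "foldseek": [0.370823, 0.364358, 0.363241, 0.360638],
}


def summarize(results):
    """Map each feature to (mean, sample standard deviation) over its folds."""
    return {feat: (mean(vals), stdev(vals)) for feat, vals in results.items()}


for metric, table in (("cindex", CINDEX), ("mse", MSE)):
    for feat, (m, s) in summarize(table).items():
        print(f"{metric:6s} {feat:8s} {m:.4f} +/- {s:.4f}")
```

On these numbers foldseek has the highest mean C-index (about 0.855 vs. 0.848 for esm) while its mean MSE is essentially the same as esm's (about 0.365 vs. 0.364), which is consistent with a marginal improvement at best.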

Code:

import matplotlib.pyplot as plt
import seaborn as sns
from statannotations.Annotator import Annotator


# Figure 4: violin plot of cross-validation results, with significance
# annotations comparing the protein (node) features
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=('shannon',),
                         show=False, add_labels=True, add_stats=True, ax=None):
    # NOTE: `exclude` and `add_labels` are currently unused
    # Keep runs with binary edges, no overlap, and no ligand features
    filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())]

    # Keep the selected dataset and only rows that carry fold (CV) info
    filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]
    nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col]
    msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col]
    # shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col]
    esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col]
    foldseek = filtered_df[(filtered_df['feat'] == 'foldseek')][sel_col]  # SaProt (foldseek) tokens

    # Print the number of folds found for each feature
    if verbose:
        print(f'nomsa: {len(nomsa)}')
        print(f'msa: {len(msa)}')
        # print(f'shannon: {len(shannon)}')
        print(f'esm: {len(esm)}')
        print(f'foldseek: {len(foldseek)}')

    # One violin per node feature. Plain arrays are passed to seaborn because all
    # four Series share the name `sel_col`, which newer seaborn versions may use as a key
    plot_data = [nomsa, msa, esm, foldseek]
    ax = sns.violinplot(data=[s.to_numpy() for s in plot_data], ax=ax)
    ax.set_xticklabels(['nomsa', 'msa', 'esm', 'foldseek'])
    ax.set_ylabel(sel_col)
    ax.set_xlabel('Features')
    ax.set_title(f'Feature {sel_col} for {sel_dataset}')

    # Pairwise Mann-Whitney tests between features, drawn as significance stars
    if add_stats:
        pairs = [(0, 1), (0, 2), (1, 2)]
        if len(foldseek) > 0:
            pairs += [(0, 3), (1, 3), (2, 3)]  # add foldseek pairs if foldseek is not empty
        annotator = Annotator(ax, pairs, data=plot_data, verbose=verbose)
        annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
                            hide_non_significant=not verbose)
        annotator.apply_and_annotate()

    if show:
        plt.show()

    return plot_data
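The figure code annotates pairs with a Mann-Whitney test. As a quick cross-check outside the plotting code, the esm vs. foldseek comparison can be run directly with `scipy.stats.mannwhitneyu` on the C-index values reported above. This is a sketch: the exact two-sided test is our choice here, and we have not confirmed that it matches statannotations' internal settings exactly.

```python
from scipy.stats import mannwhitneyu

# Davis C-index per CV fold, copied from the preliminary output values above
esm = [0.855447, 0.843247, 0.855414, 0.842134, 0.842102]
foldseek = [0.852516, 0.848126, 0.858061, 0.861351]

# Exact two-sided test; the statistic is U for the first sample (esm)
res = mannwhitneyu(esm, foldseek, alternative="two-sided", method="exact")
print(f"U = {res.statistic:.0f}, p = {res.pvalue:.3f}")
```

With only four or five folds per feature the exact test cannot produce small p-values (for five vs. four folds the smallest possible two-sided p is 2/126, about 0.016), so a missing star here says more about the number of folds than about the size of the difference.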
jyaacoub commented 4 months ago

Conclusions:

SaProt performs only marginally better than ESM, and its biggest downside is that we would still depend on large protein language models, which drastically increase compute costs for minimal performance gains.

To get around this, our focus has shifted to GVP (geometric vector perceptron) architectures, which are also "structure-aware" in the same way that the Foldseek tokens give SaProt its "structure awareness".
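To illustrate what "structure-aware" means for a GVP, here is a minimal NumPy sketch of a single GVP layer, following our reading of Jing et al. (2021): vector channels are mixed linearly, their norms feed the scalar path, and output vectors are gated by a sigmoid of their own norms. The weight shapes, the ReLU/sigmoid choice, and the random-rotation check are assumptions for illustration; this is not the MutDTA implementation.

```python
import numpy as np


def gvp_layer(s, V, W_h, W_mu, W_m, b):
    """One GVP layer (sketch).

    s: (n,) scalar features, V: (nu, 3) vector features,
    W_h: (h, nu), W_mu: (mu, h), W_m: (m, h + n), b: (m,).
    """
    V_h = W_h @ V                                  # (h, 3) mix vector channels
    V_mu = W_mu @ V_h                              # (mu, 3) output vector channels
    s_h = np.linalg.norm(V_h, axis=-1)             # (h,) rotation-invariant norms
    s_m = W_m @ np.concatenate([s_h, s]) + b       # (m,) scalar update
    s_out = np.maximum(s_m, 0.0)                   # ReLU on scalars
    norms = np.linalg.norm(V_mu, axis=-1, keepdims=True)
    V_out = V_mu / (1.0 + np.exp(-norms))          # sigmoid(norm) gate keeps directions
    return s_out, V_out


def random_rotation(rng):
    """Random 3x3 rotation matrix from a QR decomposition."""
    Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]  # flip one axis so det = +1
    return Q


rng = np.random.default_rng(0)
n, nu, h, mu, m = 6, 4, 8, 3, 5
s = rng.normal(size=n)
V = rng.normal(size=(nu, 3))
W_h = rng.normal(size=(h, nu))
W_mu = rng.normal(size=(mu, h))
W_m = rng.normal(size=(m, h + n))
b = rng.normal(size=m)

R = random_rotation(rng)
s1, V1 = gvp_layer(s, V, W_h, W_mu, W_m, b)
s2, V2 = gvp_layer(s, V @ R.T, W_h, W_mu, W_m, b)  # rotate every input vector
print(np.allclose(s1, s2), np.allclose(V1 @ R.T, V2))  # True True
```

Because only vector norms enter the scalar path and the vector nonlinearity rescales rather than re-orients, rotating the input coordinates leaves the scalar outputs unchanged and rotates the vector outputs the same way. The structure information comes straight from 3D geometry instead of from a discretized token vocabulary.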