rs-station / careless

Merge X-ray diffraction data with Wilson's priors, variational inference, and metadata
MIT License

Need for an example analysis of cross validation metrics #49

Closed DHekstra closed 1 month ago

DHekstra commented 2 years ago

The Careless paper describes the CCpred and CChalf cross-validation statistics, but neither the careless nor the careless-examples repository illustrates how to perform such an analysis.

DHekstra commented 2 years ago

I can help work on this. I think a single Jupyter notebook would suffice, showing how to analyze one careless run performed with half-dataset repeats and a test set. Alternatively, it may be most natural to do this for the S-SAD dataset as a function of the Student's t degrees of freedom.
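As a starting point for such a notebook, here is a minimal sketch of a CCpred-style calculation: the correlation between observed and predicted intensities, computed separately for training and held-out (test) reflections. It runs on synthetic data with plain pandas. The column names `Iobs`, `Ipred`, and `test`, and the helper `cc_pred`, are assumptions made for illustration and should be checked against the prediction files that careless actually writes.

```python
import numpy as np
import pandas as pd


def cc_pred(df, obs_key="Iobs", pred_key="Ipred", test_key="test", method="spearman"):
    """Correlate observed and predicted values within each value of `test_key`.

    The column names are hypothetical defaults, not a documented careless schema.
    """
    def _corr(group):
        obs, pred = group[obs_key], group[pred_key]
        if method == "spearman":
            # Spearman correlation is the Pearson correlation of the ranks.
            obs, pred = obs.rank(), pred.rank()
        return obs.corr(pred)

    return df.groupby(test_key)[[obs_key, pred_key]].apply(_corr).rename("CCpred")


# Synthetic stand-in for a predictions table: noisy observations of a known signal.
rng = np.random.default_rng(0)
n = 2000
truth = rng.gamma(shape=2.0, scale=50.0, size=n)
demo = pd.DataFrame({
    "Iobs": truth + rng.normal(scale=20.0, size=n),
    "Ipred": truth,
    "test": rng.integers(0, 2, size=n),  # 0 = training set, 1 = test set
})

ccpred = cc_pred(demo)
print(ccpred)
```

A large gap between the training and test values would suggest overfitting; in this synthetic example the two should be nearly identical because the split is random and the "model" is the same for both.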

kmdalton commented 2 years ago

One thought: we should make sure the notebook doesn't fail in the special case of a single repeat.
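A rough sketch of one way to handle that case: with a single repeat the standard error is undefined, and pandas reports it as NaN rather than raising, so the summary can simply carry the NaN along with the repeat count. The table layout (`repeat`, `bin`, `CChalf`) and the helper name `summarize_repeats` are assumptions for illustration.

```python
import numpy as np
import pandas as pd


def summarize_repeats(cc, value_key="CChalf", group_key="bin"):
    """Mean and standard error of the mean per group; NaN error for a single repeat."""
    grouped = cc.groupby(group_key)[value_key]
    return pd.DataFrame({
        "Mean": grouped.mean(),
        "Std Error": grouped.sem(),  # ddof=1, so a lone value yields NaN
        "Repeats": grouped.count(),
    })


# One repeat: the summary is still produced, with NaN standard errors.
single = pd.DataFrame({"repeat": [0, 0], "bin": [0, 1], "CChalf": [0.90, 0.50]})
single_summary = summarize_repeats(single)

# Three repeats: standard errors are finite.
multi = pd.DataFrame({
    "repeat": [0, 0, 1, 1, 2, 2],
    "bin": [0, 1, 0, 1, 0, 1],
    "CChalf": [0.90, 0.50, 0.92, 0.46, 0.88, 0.54],
})
multi_summary = summarize_repeats(multi)

print(single_summary)
print(multi_summary)
```

Reporting the repeat count next to the error makes it obvious why an error estimate is missing, which may be friendlier in a notebook than silently dropping the column.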

kmdalton commented 2 years ago

Here's a simple script I put together to analyze xval files. Perhaps take this as some inspiration.

import numpy as np
import pandas as pd
import reciprocalspaceship as rs
import seaborn as sns
from argparse import ArgumentParser
from matplotlib import pyplot as plt

desc = """
This script will generate half dataset correlations from the `*_xval_#.mtz` files output by careless. 
It will print the estimated CChalf values along with standard errors if repeats are present.
Optionally, CChalfs can be written to a csv file and/or saved as a plot in png format. 
"""

parser = ArgumentParser(description=desc)
parser.add_argument("xval_mtz", type=str, help="One *_xval_#.mtz file generated by careless.")
parser.add_argument("--num-bins", "-n", type=int, default=20, help="Number of resolution bins with default 20.")
parser.add_argument("--f-key", '-f', type=str, default='F', help="Column name to use for correlations. The default is F.")
parser.add_argument("--repeat-key", '-r', type=str, default='repeat', help="Column containing identifiers for crossvalidation repeats. This defaults to 'repeat'")
parser.add_argument("--correlation-method", '-m', type=str, default='spearman', help="Which correlation coefficient to use. This can be either 'spearman' or 'pearson'. The default is the robust 'spearman' estimator.")
parser.add_argument("--output", "-o", type=str, default=None, help="Optionally, write correlations to this text file in csv format")
parser.add_argument("--png", "-p", type=str, default=None, help="Optionally, save the resolution dependent correlations to this png file. Note the error estimates in this image are the default bootstrapped confidence intervals from seaborn (sns.lineplot).")
parser.add_argument("--dpi", type=int, default=300, help="DPI of the png file with default 300.")
parser.add_argument("--embed", action='store_true', help="Drop to an IPython shell to inspect variables after running")
# Note: from here on, `parser` is rebound to the parsed argument namespace.
parser = parser.parse_args()
xval_mtz = parser.xval_mtz
nbins = parser.num_bins

f_key = parser.f_key
repeat_key = parser.repeat_key
correlation_method = parser.correlation_method.lower()

ds = rs.read_mtz(xval_mtz)

ds,labels = ds.compute_dHKL().assign_resolution_bins(bins=nbins)

if repeat_key not in ds:
    ds[repeat_key] = 0

edges = np.concatenate([
    ds.groupby('bin').dHKL.max().to_numpy(),
    [ds[ds.bin==nbins-1].dHKL.min()]
])

label_ds = rs.DataSet({
    'bin': np.arange(nbins),
    'Resolution': labels
}).set_index('bin')

ds = ds.reset_index().set_index([repeat_key, 'bin', 'H', 'K', 'L'])[[f_key, 'half']]

ds = ds.loc[ds.half==0, [f_key]].join(ds.loc[ds.half==1, [f_key]], rsuffix='2')
cc = ds.groupby([repeat_key, "bin"]).corr(method=correlation_method).xs(f_key+'2', level=2)[[f_key]]
cc_overall = ds.groupby(repeat_key).corr(method=correlation_method).xs(f_key+'2', level=1)[[f_key]]
cc = cc.rename(columns={f_key: 'CChalf'})
cc_overall = cc_overall.rename(columns={f_key: 'CChalf'})

sns.lineplot(data=cc, x='bin', y='CChalf', color='k')
plt.xticks(
    np.arange(nbins+1) - 0.5,
    [f'{i:0.2f}' for i in edges],
    rotation_mode='anchor',
    rotation=45,
    ha='right',
)
# Raw string so that `\A` is passed to mathtext instead of being read as an invalid escape.
plt.xlabel(r"Resolution ($\AA$)")
if correlation_method == 'spearman':
    plt.ylabel("Spearman $CC_{1/2}$")
elif correlation_method == 'pearson':
    plt.ylabel("Pearson $CC_{1/2}$")
plt.tight_layout()
plt.grid(linestyle='dashdot')

if parser.png is not None:
    plt.savefig(parser.png, dpi=parser.dpi)

cc = cc.reset_index()
cc['Resolution'] = label_ds.loc[cc['bin']].to_numpy()
del cc['bin']
cc_overall['Resolution'] = 'Overall'
cc = rs.concat((cc, cc_overall.reset_index()), check_isomorphous=False)

if parser.output is not None:
    cc.to_csv(parser.output)

# Mean and standard error of the mean across repeats (NaN if there is only one repeat).
summary = cc.loc[:, ['Resolution', 'CChalf']].groupby('Resolution').apply(
    lambda x: pd.DataFrame({
        'Mean': x.mean(numeric_only=True),
        'Std Error': x.sem(numeric_only=True),
    })
)
print(summary)

if parser.embed:
    from IPython import embed
    embed(colors='linux')
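
The central step of the script is pairing the two half-dataset estimates of each reflection and correlating them within each repeat and resolution bin. A stripped-down sketch of just that step follows, using synthetic data and plain pandas. The data are made up for illustration: the `bin` assignment is arbitrary and the noise model is an assumption, not something careless produces.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n_refl, n_bins = 600, 3

# Synthetic reflections with an arbitrary bin assignment (hypothetical data).
hkl = pd.DataFrame({
    "H": np.arange(n_refl),
    "K": 0,
    "L": 0,
    "bin": np.arange(n_refl) % n_bins,
})
truth = rng.gamma(2.0, 50.0, size=n_refl)

# Two independent noisy estimates per reflection; noise grows with bin index
# to mimic weaker data at higher resolution.
halves = []
for half in (0, 1):
    part = hkl.copy()
    part["half"] = half
    part["repeat"] = 0
    noise_scale = 10.0 * (1 + 3 * part["bin"].to_numpy())
    part["F"] = truth + rng.normal(scale=noise_scale, size=n_refl)
    halves.append(part)
stacked = pd.concat(halves, ignore_index=True)

# Pair half 0 with half 1 on a shared index, then correlate within each group.
indexed = stacked.set_index(["repeat", "bin", "H", "K", "L"])
paired = indexed.loc[indexed["half"] == 0, ["F"]].join(
    indexed.loc[indexed["half"] == 1, ["F"]], rsuffix="2", how="inner"
)
cchalf = paired.groupby(["repeat", "bin"]).apply(lambda g: g["F"].corr(g["F2"]))
print(cchalf.rename("CChalf"))
```

This sketch uses the default Pearson coefficient to keep dependencies minimal; the script above defaults to the more robust Spearman estimator.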

kmdalton commented 1 month ago

I would direct folks to this bioRxiv preprint, which contains detailed descriptions of the cross-validation scheme in careless.

DHekstra commented 1 month ago

Could/should we do so on the front page README?

kmdalton commented 1 month ago

Done, @DHekstra