angelolab / Nimbus

Evaluate models / calculate scores for the pixel level #19

Closed JLrumberger closed 1 year ago

JLrumberger commented 2 years ago

Instructions

Write code to do predictions given a model checkpoint and calculate performance metrics for the predictions.

Relevant background

We trained a model on the TNBC tfrecord dataset and need to calculate pixel-wise and cell-wise evaluation metrics. This issue concerns pixel-wise metrics, but they should be implemented such that metric calculation also works for cell-wise data to re-use code.

Design overview

  1. Write a prediction routine that loads model weights, takes images and returns predictions. This should also include tile & stitch functionality.
  2. Calculate pixel-level ROC curves. Find the best threshold based on the ROC curve and report the confusion matrix, accuracy, precision and recall over all pixels, for each marker individually. We assume one dict per predicted sample with the keys ["instance_mask", "imaging_platform", "marker_activity_mask", "dataset", "marker", "activity_df", "prediction"]. We iterate through the sample dicts, calculate a ROC curve for each sample and store the ROC values in a list. We then average the ROC curves and plot the mean together with its standard deviation. The threshold is taken from the mean ROC curve. Finally, we iterate through the samples again, calculate the confusion matrix, accuracy, precision and recall for each sample and store them in a dataframe.
  3. Write a small script that loads the tfrecord validation data, uses the prediction routine from 1. and calculates the metrics from 2.
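
Averaging per-sample ROC curves as described in step 2 only works if all curves share a common axis. The following is a minimal sketch of one way to do that, assuming every curve is evaluated on the same fixed grid of thresholds so that each point of the mean curve still maps to one threshold. The function names `roc_on_grid` and `average_roc_curves`, the key `tpr_std` and the toy data are illustrative and not part of the design.

```python
import numpy as np

def roc_on_grid(y_score, y_true, thresholds):
    """Return (fpr, tpr) of one sample, evaluated at every threshold in `thresholds`."""
    y_score = np.asarray(y_score, dtype=float).ravel()
    y_true = np.asarray(y_true).ravel().astype(bool)
    thresholds = np.asarray(thresholds, dtype=float)
    # pred[i, j] is True where score j exceeds threshold i.
    # For large images this matrix gets big; looping over thresholds avoids that.
    pred = y_score[None, :] > thresholds[:, None]
    tp = (pred & y_true).sum(axis=1)
    fp = (pred & ~y_true).sum(axis=1)
    # guard against samples that contain no positives or no negatives
    tpr = tp / max(y_true.sum(), 1)
    fpr = fp / max((~y_true).sum(), 1)
    return fpr, tpr

def average_roc_curves(roc_list):
    """Average a list of (fpr, tpr) pairs that share one threshold grid."""
    fprs = np.stack([fpr for fpr, _ in roc_list])
    tprs = np.stack([tpr for _, tpr in roc_list])
    return {
        "false_positive_rate": fprs.mean(axis=0),
        "true_positive_rate": tprs.mean(axis=0),
        "tpr_std": tprs.std(axis=0),
    }

# toy data standing in for two predicted samples
thresholds = np.linspace(0.0, 1.0, 101)
sample_a = roc_on_grid([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], thresholds)
sample_b = roc_on_grid([0.2, 0.9, 0.6, 0.3], [0, 1, 1, 0], thresholds)
mean_roc = average_roc_curves([sample_a, sample_b])
```

With 101 thresholds from 0 to 1, the index of a point on the mean curve divided by 100 is its threshold, which is what the script mockup further down appears to rely on.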

Code mockup

1.

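class Pred:
    def __init__(self, params):
        self.params = params
        # the model is stored under the path given by params["model_path"]
        self.model = load_model(params["model_path"])

    def predict(self, x):
        # tile and stitch if the image is larger than the model input
        # (assumes x.shape and params["input_shape"] use the same axis order)
        if any(s > t for s, t in zip(x.shape, self.params["input_shape"])):
            y_hat = self.tile_and_stitch(x)
        else:
            y_hat = self.model(x)
        return y_hat

    def tile_and_stitch(self, x):
        ...
        return y_hat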
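
The body of `tile_and_stitch` is left open in the mockup. As a rough illustration of the idea only, the sketch below pads the image to a multiple of the tile size, predicts each non-overlapping tile and crops the padding away again. It assumes a 2D (or channels-last) numpy image and a hypothetical `model` callable that maps a tile to a 2D prediction of the same height and width; the standalone signature and the `tile_size` parameter are assumptions, not the planned implementation.

```python
import numpy as np

def tile_and_stitch(x, model, tile_size):
    """Predict on an image by splitting it into non-overlapping square tiles."""
    h, w = x.shape[:2]
    # pad height and width up to the next multiple of tile_size
    pad_h = (-h) % tile_size
    pad_w = (-w) % tile_size
    pad = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (x.ndim - 2)
    x_padded = np.pad(x, pad, mode="reflect")
    y_hat = np.zeros(x_padded.shape[:2], dtype=np.float32)
    for i in range(0, x_padded.shape[0], tile_size):
        for j in range(0, x_padded.shape[1], tile_size):
            tile = x_padded[i:i + tile_size, j:j + tile_size]
            # write the tile prediction back to its position in the full image
            y_hat[i:i + tile_size, j:j + tile_size] = model(tile)
    # crop the padded border so the output matches the input size
    return y_hat[:h, :w]

# usage with a stand-in "model" that just doubles every pixel
x = np.arange(35, dtype=np.float32).reshape(5, 7)
y_hat = tile_and_stitch(x, model=lambda tile: 2.0 * tile, tile_size=4)
```

A convolutional model sees less context near tile borders, so overlapping tiles with blending may give smoother results; the non-overlapping version is kept here for brevity.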

2.

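from sklearn.metrics import confusion_matrix

def plot_avg_roc_from_list(roc_list):
    ...
    return plot, mean_roc

def calc_metrics(y_hat, y_true, threshold):
    # binarize the predicted probabilities at the given threshold
    y_hat_binary = (y_hat > threshold).astype(int)
    # confusion_matrix takes (y_true, y_pred); ravel() yields tn, fp, fn, tp
    tn, fp, fn, tp = confusion_matrix(
        y_true.ravel(), y_hat_binary.ravel(), labels=[0, 1]
    ).ravel()
    return {
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "precision": tp / (tp + fp),
        "recall": tp / (tp + fn),
        "f_score": tp / (tp + 0.5 * (fp + fn)),
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
    }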
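
The ratios in `calc_metrics` are undefined when a sample has no predicted or no true positives, for example when a marker is absent from an image. The sketch below shows one possible way to guard against that by returning NaN, together with a worked example. The helper names `safe_div` and `metrics_from_counts` are hypothetical. Because the function only needs the four confusion counts, the same code would apply whether the counts come from pixels or from cells.

```python
import math

def safe_div(numerator, denominator):
    """Divide, but return NaN instead of raising when the denominator is zero."""
    return numerator / denominator if denominator else math.nan

def metrics_from_counts(tp, fp, tn, fn):
    """Compute the metrics of the mockup from confusion-matrix counts."""
    return {
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "f_score": safe_div(tp, tp + 0.5 * (fp + fn)),
        "accuracy": safe_div(tp + tn, tp + fp + tn + fn),
    }

# worked example: 100 pixels (or cells), 13 of them truly positive
metrics = metrics_from_counts(tp=8, fp=2, tn=85, fn=5)
# precision = 8 / 10 = 0.8, recall = 8 / 13, accuracy = 93 / 100 = 0.93

# a sample without any positive prediction or positive ground truth
empty = metrics_from_counts(tp=0, fp=0, tn=100, fn=0)
```

NaN entries are skipped by default when a pandas dataframe column is averaged, so such samples would not distort the per-marker means.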

3.

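import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve

# load tfrecord
pred = Pred(params)
roc_list = []
for sample in record:
    y_hat = pred.predict(sample["mplex_img"])
    save(y_hat)
    # y_true: ground truth of this sample (its key in the record is not fixed here)
    # roc_curve takes (y_true, y_score) and returns fpr, tpr, thresholds
    roc_list.append(roc_curve(y_true.ravel(), y_hat.ravel()))

plot, mean_roc = plot_avg_roc_from_list(roc_list)
# pick the point with the largest TPR - FPR gap; dividing the index by 100 maps it
# to a threshold only if the mean ROC is sampled at thresholds 0.00, 0.01, ..., 1.00
threshold = np.argmax(mean_roc['true_positive_rate'] - mean_roc['false_positive_rate']) / 100

metric_list = []
for sample in folder:
    metrics = calc_metrics(sample['y_hat'], sample['y_true'], threshold)
    metric_list.append(metrics)

df = pd.DataFrame(metric_list)
df.to_csv()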
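
The threshold rule in the script maximizes the difference between true positive rate and false positive rate, which is known as Youden's J statistic. A small sketch of that selection follows, assuming the mean ROC also carries its thresholds under a hypothetical `"thresholds"` key instead of relying on the index-to-threshold conversion. The function name `best_threshold` and the numbers are made up for illustration.

```python
def best_threshold(mean_roc):
    """Return the threshold maximizing TPR - FPR (Youden's J) and the J value."""
    tpr = mean_roc["true_positive_rate"]
    fpr = mean_roc["false_positive_rate"]
    j = [t - f for t, f in zip(tpr, fpr)]
    # index of the largest J value; the first one wins on ties
    best_idx = max(range(len(j)), key=j.__getitem__)
    return mean_roc["thresholds"][best_idx], j[best_idx]

# illustrative mean ROC sampled at five thresholds
mean_roc_example = {
    "thresholds": [0.0, 0.25, 0.5, 0.75, 1.0],
    "true_positive_rate": [1.0, 0.95, 0.8, 0.4, 0.0],
    "false_positive_rate": [1.0, 0.5, 0.1, 0.05, 0.0],
}
threshold_example, j_value = best_threshold(mean_roc_example)
# J per threshold is 0.0, 0.45, 0.7, 0.35, 0.0, so the threshold 0.5 is selected
```

Storing the thresholds next to the rates keeps the selection correct even if the sampling grid of the mean curve changes.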

Required inputs

  1. Takes a dict params; the model needs to be stored under the path given by params['model_path'].
  2. Takes a list of dicts, each containing the predicted probabilities and the ground truth.
  3. A script that requires params and data.

Output files

  1. Predictions, the average over the ROC curves, a dict with metrics for each image, and the plot as an image

Timeline

Give a rough estimate for how long you think the project will take. In general, it's better to be too conservative rather than too optimistic.

Estimated date when a fully implemented version will be ready for review: 19.09.

Estimated date when the finalized project will be merged in: 20.09