mlr-org / mlr

Machine Learning in R
https://mlr.mlr-org.com

Precision at given percent recall measure #1575

Closed danielhomola closed 4 years ago

danielhomola commented 7 years ago

Dear mlr devs,

Since I'm working with imbalanced datasets, I had to create a custom measure that calculates precision at a predefined level of recall. It's pretty useful if you want to optimise performance during the hyper-parameter search at a given point of the PR curve.

I don't really have the time for a proper pull request so I'll just copy the code here, and if you find it useful, please feel free to add it to the package.


make_custom_pr_measure <- function(recall_perc=5, name_str="pr5"){

  find_prec_at_recall <- function(pred, recall_perc=5){
    # Takes the prediction output of a trained mlr learner, extracts the
    # predicted probabilities of the positive class, and returns the precision
    # at the threshold whose recall is closest to the given percentage.

    # see docs here: https://cran.r-project.org/web/packages/PRROC/PRROC.pdf
    # the truth is recoded below so that positive is 1 and negative is 0.
    positive_class <- pred$task.desc$positive
    prob <- getPredictionProbabilities(pred, cl = positive_class)

    # get truth and turn it into ones (positive) and zeros (negative)
    truth <- getPredictionTruth(pred)
    if (is.factor(truth)) {
      pos_ix <- as.integer(truth) == which(levels(truth) == positive_class)
    } else {
      pos_ix <- truth == positive_class
    }
    truth <- as.integer(pos_ix)

    # Create desc sorted table of probs and truth
    df <- data.frame(truth, prob)
    df_pos <- df[which(df$truth == 1),]
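    df_pos <- BBmisc::sortByCol(df_pos, "prob", asc = FALSE)
    pos_N <- nrow(df_pos)
    # Recall is undefined when there are no positive observations
    if (pos_N == 0) return(NA_real_)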

    # Find the right threshold for x% recall, by walking through the probs in
    # the df_pos table and using each as a thrsh to calculate recall
    recall_tmp <- 0
    thrsh_tmp <- 0
    ix <- 1
    recall_target <- recall_perc/100
    while (recall_tmp < recall_target){
      # To make sure we pick the thrsh_tmp that leads us the closest to the 
      # desired recall level
      recall_tmp2 <- recall_tmp
      thrsh_tmp2 <- thrsh_tmp
      # Threshold we'll try
      thrsh_tmp <- df_pos$prob[ix]
      # Predictions that this threshold translates to
      pred_tmp <- as.numeric(df$prob >= thrsh_tmp)
      # Recall (true positive rate) = true positives / all positives
      recall_tmp <- sum(pred_tmp * df$truth)/pos_N
      ix <- ix + 1
    }

    # Two closest recall levels
    recall_tmps <- c(recall_tmp, recall_tmp2)
    # Two corresponding thresholds
    thrsh_tmps <- c(thrsh_tmp, thrsh_tmp2)
    recall_diff <- abs(recall_tmps - recall_target)
    # Threshold to use in precision calculation
    # (which.min returns a single index even if both differences are equal)
    thrsh <- thrsh_tmps[which.min(recall_diff)]
    # Find precision at this threshold
    tp <- sum(df$truth[df$prob >= thrsh])
    pred_n <- sum(df$prob >= thrsh)
    tp/pred_n
  }

  name <- paste0("Precision at ", recall_perc, "% recall")

  custom_measure <- makeMeasure(
    id = name_str, 
    name = name,
    properties = c("classif", "req.prob", "req.truth"),
    minimize = FALSE, best = 1, worst = 0,
    extra.args = list("threshold" = recall_perc),
    fun = function(task, model, pred, feats, extra.args) {
      find_prec_at_recall(pred, extra.args$threshold)
    }
  )
  custom_measure
}
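
For anyone who wants to sanity-check the idea outside of mlr, here is a minimal base R sketch of the same quantity: the precision at the score threshold whose recall is closest to a target. The helper name `precision_at_recall` and its arguments are hypothetical and not part of mlr. It assumes `truth` is already coded as 0/1, and that `recall_target` is a fraction rather than a percentage. It is a vectorised variant of the approach above, not a drop-in replacement for it.

```r
# Hypothetical helper (not part of mlr): precision at the threshold whose
# recall is closest to `recall_target`, a fraction in (0, 1].
precision_at_recall <- function(truth, prob, recall_target = 0.05) {
  stopifnot(length(truth) == length(prob), all(truth %in% c(0, 1)))
  n_pos <- sum(truth)
  if (n_pos == 0) return(NA_real_)  # recall is undefined without positives

  # Sort by descending score; every distinct score is a candidate threshold.
  ord <- order(prob, decreasing = TRUE)
  truth <- truth[ord]
  prob <- prob[ord]

  # Cumulative counts when the top-k observations are predicted positive.
  tp <- cumsum(truth)
  n_pred <- seq_along(truth)

  # Keep only the last position of each tied score, so that a threshold
  # `prob >= t` covers all observations sharing that score.
  keep <- !duplicated(prob, fromLast = TRUE)
  recall <- tp[keep] / n_pos
  precision <- tp[keep] / n_pred[keep]

  # Candidate whose recall is closest to the target; on ties which.min
  # returns the first match, i.e. the highest threshold.
  best <- which.min(abs(recall - recall_target))
  precision[best]
}

# Toy data: 4 positives among 8 observations, already sorted by score.
truth <- c(1, 0, 1, 1, 0, 0, 1, 0)
prob  <- c(0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2)

# 50% recall is reached at threshold 0.7, where 2 of the 3 predicted
# positives are true positives, so this returns 2/3.
precision_at_recall(truth, prob, recall_target = 0.5)
```

The measure returned by `make_custom_pr_measure()` should be usable like any other mlr measure, for example by passing `measures = make_custom_pr_measure(5, "pr5")` to `resample()` or `tuneParams()` with a learner created with `predict.type = "prob"`.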

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.