tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/
Other
365 stars 54 forks source link

New metric request: Survival PRC-AUC function #496

Open asb2111 opened 3 months ago

asb2111 commented 3 months ago

In situations with major class imbalance, ROC-AUC may not be a good metric to assess model concordance. Instead, as suggested in numerous places such as the scikit-learn documentation, the area under the precision-recall curve may be preferred. Functions already exist for PRC-AUC for the standard settings, but there is currently no function available in yardstick for the survival setting.

I've attached some code here that adapts the roc_auc_survival_vec function and the functions it depends on. I believe it implements the survival version of PRC-AUC by following the principles of Vock et al., who provide a general recipe for incorporating inverse probability of censoring weights into any model. The final step, after estimating the weights, is:

Apply an existing prediction method to a weighted version of the training set where each member i of the training set is weighted by a factor of $\omega_i$. In other words, if $\omega_i=3$ it is as if the observation appeared three times in the data set.

# PRC ####

# Area under the precision-recall curve, computed per evaluation time.
#
# The curve is built by prc_curve_survival_vec(), which also performs the
# input checks, so none are repeated here. Precision (`pr`) is then
# integrated over recall (`re`) separately for each `.eval_time` through
# prc_trap_auc().
#
# truth, estimate, na_rm, case_weights: forwarded unchanged to
#   prc_curve_survival_vec().
# ...: accepted but not used.
#
# Returns a data frame with one row per `.eval_time` and the columns
# `.eval_time` and `.estimate`.
prc_auc_survival_vec <- function(truth,
                                 estimate,
                                 na_rm = TRUE,
                                 case_weights = NULL,
                                 ...) {
  pr_curve <- prc_curve_survival_vec(
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    case_weights = case_weights
  )

  by_eval_time <- dplyr::group_by(pr_curve, .eval_time)
  dplyr::summarize(by_eval_time, .estimate = prc_trap_auc(pr, re))
}

# Precision-recall curve for dynamic survival predictions (vector interface).
#
# Validates the inputs with yardstick::check_dynamic_survival_metric(),
# resolves missing values according to `na_rm`, and hands the cleaned inputs
# to prc_curve_survival_impl().
#
# truth: observed survival outcome, passed through to the implementation.
# estimate: per-observation predictions; it is subset by position after
#   missing-value removal.
# na_rm: if TRUE, observations flagged by
#   yardstick::yardstick_remove_missing() are dropped; if FALSE, any missing
#   value aborts with an error.
# case_weights: optional case weights (NULL for none).
# ...: accepted but not used.
#
# Returns whatever prc_curve_survival_impl() returns.
prc_curve_survival_vec <- function(truth,
                                   estimate,
                                   na_rm = TRUE,
                                   case_weights = NULL,
                                   ...) {
  yardstick::check_dynamic_survival_metric(truth, estimate, case_weights)

  if (na_rm) {
    # Row positions stand in for `estimate` so that the positions surviving
    # the missing-value filter can be used to subset `estimate` afterwards.
    result <- yardstick::yardstick_remove_missing(
      truth,
      seq_along(estimate),
      case_weights
    )

    truth <- result$truth
    estimate <- estimate[result$estimate]
    case_weights <- result$case_weights
  } else if (yardstick::yardstick_any_missing(truth, estimate, case_weights)) {
    # The message previously referred to a non-existent `na_ra` argument.
    cli::cli_abort(
      c(
        x = "Missing values were detected and {.code na_rm = FALSE}.",
        i = "Not able to perform calculations."
      )
    )
  }

  prc_curve_survival_impl(
    truth = truth,
    estimate = estimate,
    case_weights = case_weights
  )
}

# Build the precision-recall curve for every evaluation time.
#
# Extracts the observed time and status from `truth`, defaults the case
# weights to 1 when none are supplied, removes zero-weight observations,
# unnests `estimate` into one row per observation and evaluation time, and
# then calls prc_curve_survival_impl_one() once per distinct `.eval_time`.
#
# Returns the per-time curves row-bound together, each tagged with its
# `.eval_time`.
prc_curve_survival_impl <- function(truth,
                                    estimate,
                                    case_weights) {
  event_time <- .extract_surv_time(truth)
  delta <- .extract_surv_status(truth)
  case_weights <- vctrs::vec_cast(case_weights, double())
  if (is.null(case_weights)) {
    case_weights <- rep(1, length(delta))
  }

  # Zero-weight rows shouldn't change the result, but they can produce wrong
  # thresholds, so they are removed up front.
  has_weight <- case_weights != 0
  if (!all(has_weight)) {
    event_time <- event_time[has_weight]
    delta <- delta[has_weight]
    case_weights <- case_weights[has_weight]
    estimate <- estimate[has_weight]
  }

  long <- tidyr::unnest(
    dplyr::tibble(event_time, delta, case_weights, estimate),
    cols = estimate
  )

  eval_times <- unique(long$.eval_time)
  has_prediction <- !is.na(long$.pred_survival)

  curves <- lapply(
    seq_along(eval_times),
    function(k) {
      # Rows for this evaluation time that carry a survival prediction.
      rows <- eval_times[[k]] == long$.eval_time & has_prediction

      curve <- prc_curve_survival_impl_one(
        long$event_time[rows],
        long$delta[rows],
        long[rows, ],
        long$case_weights[rows]
      )

      curve$.eval_time <- eval_times[[k]]
      curve
    }
  )

  dplyr::bind_rows(curves)
}

# Precision-recall curve at a single evaluation time.
#
# event_time: observed times for the rows in `data`.
# delta: event indicator for the rows in `data`.
# data: rows for one evaluation time; the columns `.eval_time`,
#   `.pred_survival`, `.weight_censored` and `delta` are read.
# case_weights: case weights for the rows in `data`.
#
# Observations are grouped by their `.pred_survival` value and the groups are
# accumulated in increasing order of predicted survival. For each cumulative
# group:
#   re = cumulative weighted events observed by the evaluation time, divided
#        by the total of that quantity over all rows;
#   pr = the same cumulative numerator, divided by the cumulative sum of
#        `case_weights * .weight_censored`.
# Both are clamped to [0, 1]; an undefined precision is set to 0.
#
# Returns a tibble with the columns `.threshold`, `re` and `pr`.
#
# NOTE(review): `.threshold` is stored in decreasing order while `re` and
# `pr` accumulate over increasing `.pred_survival`; confirm the intended row
# alignment before relying on `.threshold` values.
prc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) {
  res <- dplyr::tibble(
    .threshold = sort(
      unique(c(-Inf, data$.pred_survival, Inf)),
      decreasing = TRUE
    )
  )

  obs_time_le_time <- event_time <= data$.eval_time

  # Total weighted number of events observed by the evaluation time.
  re_denom <- sum(
    obs_time_le_time * delta * data$.weight_censored * case_weights,
    na.rm = TRUE
  )

  data_df <- data.frame(
    le_time = obs_time_le_time,
    delta = data$delta,
    weight_censored = data$.weight_censored,
    case_weights = case_weights
  )

  # One data frame per distinct predicted survival, ordered by that value.
  data_split <- vctrs::vec_split(data_df, data$.pred_survival)
  data_split <- data_split$val[order(data_split$key)]

  # Weighted events per group. This is the numerator of both recall and
  # precision, so it is computed once and shared.
  event_weight <- vapply(
    data_split,
    function(x) {
      sum(x$le_time * x$delta * x$weight_censored * x$case_weights, na.rm = TRUE)
    },
    FUN.VALUE = numeric(1)
  )

  # Total weight per group, regardless of event status.
  total_weight <- vapply(
    data_split,
    function(x) sum(x$case_weights * x$weight_censored, na.rm = TRUE),
    FUN.VALUE = numeric(1)
  )

  cum_event_weight <- cumsum(event_weight)

  re <- cum_event_weight / re_denom
  re <- dplyr::if_else(re > 1, 1, re)
  re <- dplyr::if_else(re < 0, 0, re)
  # Pad so the curve spans recall 0 to 1 and lines up with the two infinite
  # thresholds.
  re <- c(0, re, 1)
  res$re <- re

  pr <- cum_event_weight / cumsum(total_weight)
  pr <- dplyr::if_else(pr > 1, 1, pr)
  pr <- dplyr::if_else(pr < 0 | is.na(pr), 0, pr)
  # Pad with the smallest and largest observed precision.
  pr <- c(min(pr, na.rm = TRUE), pr, max(pr, na.rm = TRUE))
  res$pr <- pr

  res
}

# Area under a precision-recall curve.
#
# pr: precision values.
# re: recall values, same length as `pr`.
#
# Pairs in which either value is missing are dropped; the integration itself
# is delegated to yardstick's internal `auc()`, with recall as the first
# argument and precision as the second.
prc_trap_auc <- function(pr, re) {
  complete <- !(is.na(pr) | is.na(re))

  yardstick:::auc(re[complete], pr[complete])
}

# S3 generic for the survival precision-recall curve.
#
# data: object to dispatch on.
# ...: passed on to the selected method.
prc_curve_survival <- function(data, ...) {
  UseMethod("prc_curve_survival")
}

# Data frame method for prc_curve_survival().
#
# data: data frame holding the truth and prediction columns.
# truth: unquoted column selection for the observed outcome.
# ...: forwarded to the summarizer (column selection for the predictions).
# na_rm: forwarded to prc_curve_survival_vec().
# case_weights: optional unquoted column selection for case weights.
#
# The curve is computed by prc_curve_survival_vec() through yardstick's curve
# summarizer, then passed to yardstick's internal curve_finalize() together
# with the class names "prc_survival_df" / "grouped_prc_survival_df".
prc_curve_survival.data.frame <- function(data,
                                          truth,
                                          ...,
                                          na_rm = TRUE,
                                          case_weights = NULL) {
  # Namespace-qualified like the other yardstick calls in this file, so the
  # method does not depend on yardstick being attached.
  result <- yardstick::curve_survival_metric_summarizer(
    name = "prc_curve_survival",
    fn = prc_curve_survival_vec,
    data = data,
    truth = !!enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )

  yardstick:::curve_finalize(
    result,
    data,
    "prc_survival_df",
    "grouped_prc_survival_df"
  )
}

# Plot survival precision-recall curves, one step line per evaluation time.
#
# object: data frame with the columns `.eval_time`, `re` (recall) and `pr`
#   (precision).
# ...: accepted but not used.
#
# Returns the ggplot object.
autoplot.prc_survival_df <- function(object, ...) {
  `%+%` <- ggplot2::`%+%`

  # format() turns the evaluation times into character labels before they
  # are mapped to colour and group.
  object$.eval_time <- format(object$.eval_time)

  curve_layer <- ggplot2::geom_step(
    mapping = ggplot2::aes(
      x = re,
      y = pr,
      color = .eval_time,
      group = .eval_time
    ),
    direction = "hv"
  )

  ggplot2::ggplot(data = object) %+%
    curve_layer %+%
    ggplot2::coord_equal() %+%
    ggplot2::theme_bw() %+%
    ggplot2::xlab("Recall") %+%
    ggplot2::ylab("Precision")
}

# S3 generic for the survival PRC-AUC, registered as a dynamic survival
# metric whose direction is "maximize".
#
# data: object to dispatch on.
# ...: passed on to the selected method.
prc_auc_survival <- yardstick::new_dynamic_survival_metric(
  function(data, ...) {
    UseMethod("prc_auc_survival")
  },
  direction = "maximize"
)

# Data frame method for prc_auc_survival().
#
# data: data frame holding the truth and prediction columns.
# truth: unquoted column selection for the observed outcome.
# ...: forwarded to the summarizer (column selection for the predictions).
# na_rm: forwarded to prc_auc_survival_vec().
# case_weights: optional unquoted column selection for case weights.
#
# Returns the result of yardstick::dynamic_survival_metric_summarizer()
# applied with prc_auc_survival_vec() as the metric function.
prc_auc_survival.data.frame <- function(data,
                                        truth,
                                        ...,
                                        na_rm = TRUE,
                                        case_weights = NULL) {
  # Capture the column expressions once, then inject them below.
  truth_quo <- enquo(truth)
  case_weights_quo <- enquo(case_weights)

  yardstick::dynamic_survival_metric_summarizer(
    name = "prc_auc_survival",
    fn = prc_auc_survival_vec,
    data = data,
    truth = !!truth_quo,
    ...,
    na_rm = na_rm,
    case_weights = !!case_weights_quo
  )
}
EmilHvitfeldt commented 3 months ago

Thank you for the suggestion!