tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/

roc_auc_survival appears to be using the wrong weights #495

Closed asb2111 closed 5 months ago

asb2111 commented 6 months ago

I apologize in advance if I am reading the function wrong.

In the function `roc_curve_survival_impl_one`, the sensitivity appears to be computed as
$$\left(\sum \frac{\Delta\, I(T\le t)}{n\, w(t)}\right)^{-1} \sum \frac{\Delta\, I(T\le t)\, I(Pred = 1)}{n\, w(t)} = \left(\sum \Delta\, I(T\le t)\, P(C>t)\right)^{-1} \sum \Delta\, I(T\le t)\, I(Pred = 1)\, P(C>t),$$
where $I(x)$ is an indicator taking value 1 if $x$ is true, $w(t)$ is the inverse probability of censoring weight for the subject at time $t$ (so $1/w(t)$ is an estimate of $P(C>t)$, which gives the right-hand side since the $1/n$ cancels in the ratio), and $Pred$ is 1 if the subject's predicted value exceeds the threshold value.

For sensitivity, the denominator should be the IPC-weighted number of positives, and the numerator should be the IPC-weighted number of positives labeled as positive. But here, instead of applying the weights to the numerator and the denominator, the actual probabilities of remaining uncensored are being applied.

I believe the correction should be changing `multiplier <- delta / (n * data$.weight_censored)` to `multiplier <- delta * data$.weight_censored`.
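
To make the intended weighting concrete, here is a minimal sketch (not yardstick code) of the IPC-weighted sensitivity described above, on made-up values; `by_time`, `pred_pos`, and `w` are hypothetical stand-ins for $I(T\le t)$, $I(Pred = 1)$, and the censoring weight:

delta    <- c(1, 1, 0, 1)               # event indicator
by_time  <- c(TRUE, TRUE, TRUE, FALSE)  # I(T <= t): observed time at or before the eval time
pred_pos <- c(TRUE, FALSE, TRUE, TRUE)  # I(Pred = 1): prediction exceeds the threshold
w        <- c(1.2, 1.5, NA, 1.1)        # IPC weights (NA where the weight is not usable)

# numerator: IPC-weighted events that are also labeled positive
num <- sum(by_time * delta * pred_pos * w, na.rm = TRUE)
# denominator: IPC-weighted events
denom <- sum(by_time * delta * w, na.rm = TRUE)
num / denom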

asb2111 commented 6 months ago

Similarly, it appears the specificity is not using the weights at all.

asb2111 commented 6 months ago

A proposed change to the function:

roc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) {
  # thresholds on the predicted survival probability, from high to low
  res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)), decreasing = TRUE))

  obs_time_le_time <- event_time <= data$.eval_time
  obs_time_gt_time <- event_time > data$.eval_time
  n <- nrow(data)

  sensitivity_denom <- sum(obs_time_le_time * delta * data$.weight_censored * case_weights, na.rm = TRUE)
  specificity_denom <- sum(obs_time_gt_time * data$.weight_censored * case_weights, na.rm = TRUE)

  data_df <- data.frame(
    le_time = obs_time_le_time,
    ge_time = obs_time_gt_time,
    delta = delta,
    weight_censored = data$.weight_censored,
    case_weights = case_weights
  )

  # split rows by unique predicted survival value
  data_split <- vctrs::vec_split(data_df, data$.pred_survival)
  data_split <- data_split$val[order(data_split$key)]

  sensitivity <- vapply(
    data_split,
    function(x) sum(x$le_time * x$delta * x$weight_censored * x$case_weights, na.rm = TRUE),
    FUN.VALUE = numeric(1)
  )

  sensitivity <- cumsum(sensitivity)
  sensitivity <- sensitivity / sensitivity_denom
  sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity)
  sensitivity <- dplyr::if_else(sensitivity < 0, 0, sensitivity)
  sensitivity <- c(0, sensitivity, 1)
  res$sensitivity <- sensitivity

  specificity <- vapply(
    data_split,
    function(x) sum(x$ge_time * x$weight_censored * x$case_weights, na.rm = TRUE),
    FUN.VALUE = numeric(1)
  )
  specificity <- cumsum(specificity)
  specificity <- specificity / specificity_denom
  specificity <- dplyr::if_else(specificity > 1, 1, specificity)
  specificity <- dplyr::if_else(specificity < 0, 0, specificity)
  specificity <- c(0, specificity, 1)
  specificity <- 1 - specificity
  res$specificity <- specificity

  res
}
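
If it helps, here is a hypothetical call of the proposed function on made-up data; the column names follow what the function expects (per the yardstick conventions), but the numbers are purely illustrative:

event_time <- c(2, 4, 5, 7, 9, 10)
delta      <- c(1, 1, 0, 1, 0, 1)  # 1 = event observed, 0 = censored
toy <- dplyr::tibble(
  .pred_survival   = c(0.9, 0.7, 0.6, 0.4, 0.3, 0.2),
  .eval_time       = 5,
  .weight_censored = c(1.0, 1.1, NA, 1.3, 1.4, 1.6)
)

roc_curve_survival_impl_one(event_time, delta, toy, case_weights = rep(1, 6))
# returns a tibble with .threshold, sensitivity, and specificity columns
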
EmilHvitfeldt commented 6 months ago

Hello @asb2111 👋 Thanks for your interest and for looking into this issue!

Do you mind referencing a paper/website to support this change?

Thank you!

asb2111 commented 6 months ago

Hi! Thanks for taking a look at this. Just to reiterate, I may be misreading the code or misunderstanding the implementation, so if my comment doesn't make sense, I apologize.

In Blanche et al., section 3.3 describes the inverse probability of censoring weighting estimator, and you can see that in the numerator and denominator each term is divided by $\hat{S}_C(T^*_i)$, the probability of remaining uncensored. The weights are dropped in the specificity because they are all equal by definition, but in my implementation I put them in anyway in case a future version wants to allow a conditional inverse probability of censoring weighted estimator such as in equation 4.3 of this paper. If you want to do the simplest fix to the problem, you could just replace `multiplier <- delta / (n * data$.weight_censored)` with `multiplier <- delta / (n * data$.pred_censored)`. I can confirm that this small change gives the same result as my code above.
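
For what it's worth, the reason those two forms agree (assuming `.weight_censored` is the IPC weight $1/\hat{S}_C(t)$, and reading `.pred_censored` as a hypothetical column holding $\hat{S}_C(t)$ itself) is that

$$\frac{\delta}{n\,\hat{S}_C(t)} = \frac{1}{n}\,\delta\, w(t), \qquad w(t) = \frac{1}{\hat{S}_C(t)},$$

and the constant $1/n$ appears in both the numerator and the denominator of the sensitivity, so it cancels.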

A general discussion of using IPCW for prediction problems is found in Vock et al., where they provide a general recipe for incorporating inverse probability of censoring weights into any model. The final step, after estimating the weights, is:

Apply an existing prediction method to a weighted version of the training set where each member i of the training set is weighted by a factor of $\omega_i$. In other words, if $\omega_i=3$ it is as if the observation appeared three times in the data set.
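
As a quick illustration of that recipe with made-up numbers (not yardstick internals), weighting an observation by an integer $\omega_i$ behaves exactly like repeating it $\omega_i$ times:

y <- c(1, 0, 1)  # some outcome
w <- c(3, 1, 2)  # integer weights, for illustration only

weighted   <- sum(w * y) / sum(w)       # weight each observation by w
replicated <- mean(rep(y, times = w))   # or literally repeat each observation w times
c(weighted, replicated)                 # both equal 5/6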

We can see this being implemented in the "Illustration: Confusion Matrix" section of the survival-metrics-details vignette on the tidymodels website, where the confusion matrix is constructed with:

binary_encoding %>%
  filter(.eval_time == 1.00) %>%
  conf_mat(truth = obs_class,
           estimate = pred_class,
           case_weights = .weight_censored)

Following the rabbit hole down from `conf_mat`, we get to `hardhat::weighted_table()`, which ultimately gets us to:

  tapply(
    X = weights,
    INDEX = args,
    FUN = sum,
    na.rm = na_remove,
    default = 0,
    simplify = TRUE
  )

So here it looks like it is summing the weights to get the elements of the confusion matrix. That would correspond to treating each person the way described by Vock et al. With this, the sensitivity would become
$$\left(\sum \Delta\, I(T\le t)\, w(t)\right)^{-1} \sum \Delta\, I(T\le t)\, I(Pred = 1)\, w(t),$$
which is the weighted number of people who are both predicted to be events and actually are events, divided by the weighted number of people who actually are events. The code I put above is my attempt at implementing that.
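
A quick sketch with made-up factors and weights (not the vignette's data) shows that weight-summing behavior directly:

truth    <- factor(c("event", "event", "no_event"), levels = c("event", "no_event"))
estimate <- factor(c("event", "no_event", "no_event"), levels = c("event", "no_event"))
w        <- c(1.2, 1.5, 2.0)

# each cell of the table is the sum of the weights for that truth/estimate pair
hardhat::weighted_table(truth, estimate, weights = w)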

EmilHvitfeldt commented 6 months ago

Thank you for the thorough answer! I need to take some time to read through this fully.

github-actions[bot] commented 5 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.