kdpsingh / runway

Visualizing Prediction Model Performance

threshperf_plot (single and multi) does not seem to work, likely because of the recode_data function #3

Open aliarsalankazmi opened 4 years ago

aliarsalankazmi commented 4 years ago

Thanks for this wonderful package!

While playing with the threshperf plot functions, I found that I can run the package example below:

library(runway)   
threshperf_plot(single_model_dataset,   
                outcome = 'outcomes',    
                prediction = 'predictions')

However, using the function on another simple model does not work:

library(tidyverse)
library(rms)

m1 <- glm(am ~ wt, data = mtcars, family = binomial)
m2 <- lrm(am ~ rcs(wt), data = mtcars)

res1 <- predict(m1, type = "response") %>% tibble(prediction = .) %>% bind_cols(mtcars %>% select(am))
res2 <- predict(m2, type = "fitted") %>% tibble(prediction = .) %>% bind_cols(mtcars %>% select(am))
allRes <- bind_rows(res1 %>% mutate(model = "model 1"),
                    res2 %>% mutate(model = "model 2")) %>% 
  mutate(am = as.factor(am))

threshperf_plot_multi(allRes, 'am', 'prediction', 'model')

If the error comes from the package code, it is most likely happening in the recode_data() call:

  df <- df %>% expand_preds(threshold = thresholds, inc = c(outcome, 
    prediction)) %>% dplyr::mutate(alt_pred = recode_data(df[[outcome]], 
    df[[prediction]], .threshold))

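After expand_preds(), the data frame has more rows than the original df (one copy of the data per threshold). But df[[outcome]] and df[[prediction]] in recode_data() still come from the original, unexpanded df, so pairing them with the .threshold column of the expanded data frame is a length mismatch.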
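To make the mismatch concrete, here is a minimal base-R sketch with made-up toy values. It uses `merge(..., by = NULL)`, which returns the Cartesian product of two data frames, as a stand-in for runway's expand_preds(); treating expand_preds() as "one copy of the rows per threshold" is an assumption about that helper, not its actual code.

```r
# Toy data (made-up values): 4 predictions with their 0/1 outcomes
df <- data.frame(outcome = c(1, 0, 1, 0),
                 prediction = c(0.9, 0.2, 0.6, 0.4))
thresholds <- c(0, 0.5, 1)

# Cross join: one copy of every row per threshold. This only stands in for
# expand_preds(); the real helper may differ in details such as column order.
expanded <- merge(df, data.frame(.threshold = thresholds), by = NULL)

nrow(expanded)         # 12 rows (4 rows x 3 thresholds)
length(df$prediction)  # still 4 values, so it cannot line up with .threshold

# Classifying with columns taken from the expanded data itself keeps every
# prediction paired with its own threshold
expanded$alt_pred <- as.integer(expanded$prediction >= expanded$.threshold)
```

In plain base R the 4-value vector would be silently recycled against the 12 thresholds; inside `dplyr::mutate()`, a result whose length is neither 1 nor the number of rows is rejected with an error.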

kdpsingh commented 4 years ago

Ali, thank you for catching this issue and for flagging the potential cause. I will look into this and try to resolve it in the next few days.

aliarsalankazmi commented 4 years ago

Thanks, Karandeep!

For what it's worth, I made the following changes, and with them the functions work on my end; I just don't know whether or how to push them.

The changes were primarily about NSE (non-standard evaluation): in your code, variable names are passed as strings, and these need to be handled differently in the tidyverse framework.
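A minimal sketch of two common ways to use a column name held in a string inside a dplyr verb, with made-up toy data. `prob_col` is a hypothetical variable name, deliberately chosen so it does not clash with any column. The first option mirrors the `!!as.name()` style in the proposed fix; the second uses dplyr's `.data` pronoun.

```r
library(dplyr)

# Toy data; the column name arrives as a string, as in threshperf(df, outcome, prediction)
df <- data.frame(am = c(1, 0, 1), prob = c(0.8, 0.3, 0.6))
prob_col <- "prob"  # hypothetical name that does not match any column

# Option 1: turn the string into a symbol and unquote it with !!
# (parentheses keep !! applied to the symbol only, not to the comparison)
via_symbol <- df %>% mutate(hit = (!!as.name(prob_col)) >= 0.5)

# Option 2: look the column up by its string name through the .data pronoun
via_pronoun <- df %>% mutate(hit = .data[[prob_col]] >= 0.5)
```

Both produce the same `hit` column; the `.data[[...]]` form avoids building symbols by hand.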

Hope this reduces your workload 😃

threshperf <- function(df, outcome, prediction) {

  thresholds = unique(c(0,sort(unique(df[[prediction]])), 1))

  df <- dplyr::select(df, dplyr::all_of(c(outcome, prediction)))

  df <- na.omit(df)

  df_orig <- df

  # IMPORTANT because order of levels matters to yardstick
  if (getOption('yardstick.event_first', default = TRUE)) {
    df[[outcome]] <- factor(df[[outcome]], levels = c(1,0))
  } else {
    df[[outcome]] <- factor(df[[outcome]], levels = c(0,1))
  }

  dfInterim <-
    df %>%
    expand_preds(threshold = thresholds,
                 inc = c(outcome, prediction))
  dfFinal <-
    dfInterim %>%
    dplyr::mutate(alt_pred = recode_data(!!as.name(outcome), !!as.name(prediction), .threshold))

  df <- dfFinal %>% dplyr::group_by(.threshold)

  df_metrics <- df %>%
    two_class(truth = get(outcome), estimate = alt_pred)

  df_metrics <-
    df_metrics %>%
    dplyr::group_by(.threshold) %>%
    dplyr::mutate(denom =
                    dplyr::case_when(
                      .metric == 'sens' ~ sum(df_orig[[outcome]] == 1),
                      .metric == 'spec' ~ sum(df_orig[[outcome]] == 0),
                      .metric == 'ppv' ~ sum(df_orig[[prediction]] >= .threshold),
                      .metric == 'npv' ~ sum(df_orig[[prediction]] < .threshold),
                    )) %>%
    dplyr::ungroup() %>%
    dplyr::mutate(numer = round(.estimate * denom)) %>%
    na.omit()

  df_ci = Hmisc::binconf(x = df_metrics$numer, n = df_metrics$denom,
                         alpha = 0.05, method = 'wilson') %>%
    dplyr::as_tibble() %>%
    dplyr::rename(ll = Lower, ul = Upper) %>%
    dplyr::mutate_at(dplyr::vars(ul, ll), . %>% scales::oob_squish(range = c(0,1)))

  df_metrics = dplyr::bind_cols(df_metrics, df_ci)

  data.frame(df_metrics, check.names = FALSE, stringsAsFactors = FALSE)
}
threshperf_plot_multi <- function(df, outcome, prediction, model, plot_title = '') {

  how_many_models = df[[model]] %>% unique() %>% length()

  tp_data_list = list()
  for (model_name in unique(df[[model]])) {
    tp_data_list[[model_name]] <-
      threshperf(df[df[[model]] == model_name,],
                 outcome,
                 prediction)
    tp_data_list[[model_name]][[model]] <- model_name
  }

  tp_data = dplyr::bind_rows(tp_data_list)

  tp_plot =
    tp_data %>%
    dplyr::mutate(.metric = dplyr::case_when(
      .metric == 'npv' ~ 'NPV',
      .metric == 'ppv' ~ 'PPV',
      .metric == 'spec' ~ 'Specificity',
      .metric == 'sens' ~ 'Sensitivity')) %>%
    dplyr::mutate(.metric = factor(.metric, levels = c('Sensitivity', 'Specificity', 'PPV', 'NPV'))) %>%
    dplyr::mutate_at(dplyr::vars(.estimate, ll, ul), . %>% {. * 100}) %>%
    ggplot2::ggplot(ggplot2::aes(x = .threshold,
                                 y = .estimate,
                                 ymin = ll,
                                 ymax = ul,
                                 color = !!as.name(model),
                                 fill = !!as.name(model))) +
    ggplot2::geom_ribbon(alpha = 1/how_many_models) +
    ggplot2::geom_line(size = 1) +
    ggplot2::facet_grid(.metric~.) +
    ggplot2::theme_bw() +
    ggplot2::labs(x = 'Threshold', y = 'Performance (%)') +
    ggplot2::scale_color_brewer(name = 'Models', palette = 'Set1') +
    ggplot2::scale_fill_brewer(name = 'Models', palette = 'Set1') +
    ggplot2::ggtitle(plot_title)

  threshold_dist_plot <- ggplot2::ggplot(df, ggplot2::aes(x = !!as.name(prediction))) +
    ggplot2::geom_density(alpha = 1/how_many_models, ggplot2::aes(fill = !!as.name(model), color = !!as.name(model))) +
    ggplot2::scale_x_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.1)) +
    # scale_color_viridis(discrete = TRUE, option = 'cividis', begin = 0.5) +
    # scale_fill_viridis(discrete = TRUE, option = 'cividis', begin = 0.5) +
    ggplot2::xlab("") +
    ggplot2::ylab("") +
    ggplot2::scale_color_brewer(palette = 'Set1') +
    ggplot2::scale_fill_brewer(palette = 'Set1') +
    ggplot2::theme_minimal() +
    ggeasy::easy_remove_y_axis() +
    #  easy_remove_x_axis(what = c('ticks','line')) +
    ggeasy::easy_remove_legend(fill, color) +
    ggplot2::theme_void()

  patchwork::plot_spacer() +
    (tp_plot / threshold_dist_plot + patchwork::plot_layout(heights = c(10,1))) +
    patchwork::plot_spacer() +
    patchwork::plot_layout(widths = c(1,2,1))
}

kdpsingh commented 4 years ago

@aliarsalankazmi Thank you so much for troubleshooting and proposing a fix. As you point out, I had switched this from NSE to standard eval and must have made an error there. I should have a fix posted by Monday.

kdpsingh commented 4 years ago

Thank you @aliarsalankazmi for sharing a working example of code. Sorry it took me a while to get to this. I didn't incorporate all of the changes you suggested, but tried to address this with the minimum number of changes.

The issue, as I understand it, wasn't actually a standard evaluation vs. non-standard evaluation issue. I had rewritten the recode_data() function from the probably R package to use standard eval.

Instead, I think the issue was my chaining of expand_preds() into dplyr::mutate(): inside that mutate(), df[[prediction]] and df[[outcome]] did not have the same number of rows as the data produced by expand_preds(). Breaking this into 2 steps should fix the issue.

If you have a chance, I would love confirmation that the issue is fixed. If not, I'll go back and take another look.

kdpsingh commented 4 years ago

I'll close this issue for now. If you find that my attempted bug fix did not work, post a follow-up message and I'll re-open this issue. Thanks!

aliarsalankazmi commented 4 years ago

Hi Karandeep,

Thanks for the effort! I just checked, but I am still getting the error shown below:

remotes::install_github('ML4LHS/runway')
library(runway)
library(tidyverse)
library(rms)

m1 <- glm(am ~ wt, data = mtcars, family = binomial)
m2 <- lrm(am ~ rcs(wt), data = mtcars)

res1 <- predict(m1, type = "response") %>% tibble(prediction = .) %>% bind_cols(mtcars %>% select(am))
res2 <- predict(m2, type = "fitted") %>% tibble(prediction = .) %>% bind_cols(mtcars %>% select(am))
allRes <- bind_rows(res1 %>% mutate(model = "model 1"),
                    res2 %>% mutate(model = "model 2")) %>% 
  mutate(am = as.factor(am))

threshperf_plot_multi(allRes, 'am', 'prediction', 'model')
Error: Problem with `mutate()` input `alt_pred`.
x Must extract column with a single valid subscript.
x Can't convert from <double> to <integer> due to loss of precision.
i Input `alt_pred` is `recode_data(df[[outcome]], df[[prediction]], .threshold)`.
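The message points at `df[[prediction]]` inside `mutate()`. A minimal sketch of how dplyr's data masking can produce this kind of failure, assuming `prediction` is a function argument holding the column name as a string: when the data also has a column called `prediction`, the bare symbol `prediction` inside `mutate()` resolves to that numeric column instead of the string, so `df[[...]]` gets indexed with a vector of probabilities. The toy values are made up; `.data` and `.env` are pronouns provided by dplyr/rlang.

```r
library(dplyr)

# Toy data whose probability column happens to be named "prediction"
df <- data.frame(am = c(1, 0, 1, 0), prediction = c(0.9, 0.2, 0.6, 0.4))
prediction <- "prediction"  # string holding the column name

# Inside mutate(), the symbol `prediction` is looked up in the data first,
# so this indexes df with the numeric column instead of the string and fails
bad <- try(df %>% mutate(hit = df[[prediction]] >= 0.5), silent = TRUE)

# Taking the name explicitly from the calling environment avoids the clash
good <- df %>% mutate(hit = .data[[.env$prediction]] >= 0.5)
```

This also explains why the failure depends on the column's name: with a column called, say, `pred`, the symbol `prediction` would not be masked.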

kdpsingh commented 3 years ago

Thanks @aliarsalankazmi for sharing an example. Sorry for taking a while to get back to you. One of my students and I had a chance to look at this. We found 2 reasons why the code isn't working: one is a bug in our code and the other is a design flaw that we plan to fix.

Here are the 2 issues:

  1. Because we use "get()" in our code, there is a bug where the prediction column can't be named "prediction." This is a bug that we will fix shortly.
  2. Right now, runway expects outcomes to be numeric (1s and 0s) and not factors. This was done to ensure that the correct class was being considered as positive. I'm going to change this so that the outcome is a factor/character and you manually specify, in an argument, which class should be considered positive. Should have a fix up in the next few days.
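A minimal sketch of the kind of conversion point 2 describes. The helper `encode_outcome()` and its `positive` argument are hypothetical names, not part of runway: the outcome may be a factor or character vector, and the caller names the class to be treated as the event.

```r
# Hypothetical helper: turn a factor/character outcome into 1/0 using an
# explicitly chosen positive class (the name `positive` is an assumption)
encode_outcome <- function(outcome, positive) {
  outcome <- as.character(outcome)
  if (length(positive) != 1 || !(positive %in% outcome)) {
    stop("`positive` must be a single value that occurs in the outcome")
  }
  # 1 for the chosen positive class, 0 for everything else
  as.integer(outcome == positive)
}

encode_outcome(factor(c("yes", "no", "yes")), positive = "yes")
```

With the mtcars example above, `encode_outcome(allRes$am, positive = "1")` would recover the 1/0 coding that runway currently expects, while making the choice of positive class explicit.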