tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/

Survival *_vec metrics seem to assume the .data.frame version of inputs #484

Open tripartio opened 9 months ago

tripartio commented 9 months ago

Hello. This is my first time here, so I will take the opportunity to first of all thank you for your fantastic package, which is an essential part of my ML workflow.

I'm having some trouble with the *_vec versions of some survival metrics in the development version of {yardstick} (version 1.2.0.9003). (I don't think these metrics exist in the 1.2.0 CRAN version.) Here is a reprex:

library(yardstick)

lung_surv |> 
  roc_auc_survival(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 5 × 4
#>   .metric          .estimator .eval_time .estimate
#>   <chr>            <chr>           <dbl>     <dbl>
#> 1 roc_auc_survival standard          100     0.659
#> 2 roc_auc_survival standard          200     0.679
#> 3 roc_auc_survival standard          300     0.688
#> 4 roc_auc_survival standard          400     0.648
#> 5 roc_auc_survival standard          500     0.662

roc_auc_survival_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred_time
)
#> Error in `roc_curve_survival_vec()`:
#> ! `estimate` should be a list, not a a double vector.

lung_surv |> 
  brier_survival(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 5 × 4
#>   .metric        .estimator .eval_time .estimate
#>   <chr>          <chr>           <dbl>     <dbl>
#> 1 brier_survival standard          100     0.109
#> 2 brier_survival standard          200     0.194
#> 3 brier_survival standard          300     0.219
#> 4 brier_survival standard          400     0.222
#> 5 brier_survival standard          500     0.197

brier_survival_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred_time
)
#> Error in `brier_survival_vec()`:
#> ! `estimate` should be a list, not a a double vector.

lung_surv |> 
  brier_survival_integrated(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 1 × 3
#>   .metric                   .estimator .estimate
#>   <chr>                     <chr>          <dbl>
#> 1 brier_survival_integrated standard       0.158

brier_survival_integrated_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred_time
)
#> Error in `brier_survival_integrated_vec()`:
#> ! `estimate` should be a list, not a a double vector.

Created on 2024-01-18 with reprex v2.1.0

I initially tried to fix the code and make a PR, but I soon found that the interwoven structure of the functions is quite complicated; resolving this probably requires someone familiar with how these functions were built. But I will share my attempts to isolate the bug.

It seems that the immediate source of the bug might be in the check_dynamic_survival_metric() function in the file check-metric.R, which all six metric functions call (that is, the *.data.frame versions as well as the *_vec versions):

check_dynamic_survival_metric <- function(truth,
                                          estimate,
                                          case_weights,
                                          call = caller_env()) {
  validate_surv_truth_list_estimate(truth, estimate, call = call)
  validate_case_weights(case_weights, size = nrow(truth), call = call)
}

Specifically, these metric functions all call validate_surv_truth_list_estimate(). However, that validation function seems to assume a very particular structure of the dataframe or tibble that is passed to the metric functions. Apart from the fact that it is not well documented how to generate that structure, the bug seems to be that the *_vec functions also assume that dataframe structure.
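For readers following along, the shape being checked is the one described in the yardstick survival metric documentation: a list with one data frame per subject, each holding .eval_time, .pred_survival, and .weight_censored columns. The sketch below is a hypothetical base-R check of that shape. is_surv_pred_list() is a name made up here; it is not yardstick's validate_surv_truth_list_estimate(), and it only illustrates why a plain numeric vector is rejected.

```r
# Hypothetical check (NOT yardstick's validate_surv_truth_list_estimate()):
# does `estimate` have the list-of-data-frames shape described in the
# yardstick survival metric docs, with one element per subject?
is_surv_pred_list <- function(estimate, n) {
  required <- c(".eval_time", ".pred_survival", ".weight_censored")
  is.list(estimate) && !is.data.frame(estimate) &&
    length(estimate) == n &&
    all(vapply(estimate, function(x) {
      is.data.frame(x) && all(required %in% names(x))
    }, logical(1)))
}

# A plain numeric vector (like lung_surv$.pred_time) fails the check,
# which mirrors the "`estimate` should be a list" error in the reprex.
is_surv_pred_list(c(120, 340), n = 2)
#> [1] FALSE

# A list of per-subject data frames with the required columns passes.
ok <- list(
  data.frame(.eval_time = 100, .pred_survival = 0.8, .weight_censored = 1),
  data.frame(.eval_time = 100, .pred_survival = 0.6, .weight_censored = 1)
)
is_surv_pred_list(ok, n = 2)
#> [1] TRUE
```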

Following the code for the roc_auc_survival() and roc_auc_survival_vec() functions (roc_auc_survival_vec() is the one I personally need), it seems that in surv-roc_curve_survival.R, anything that enters roc_curve_survival_vec() is passed on to roc_curve_survival_impl(), which assumes that its truth argument is a dataframe in the desired format rather than a vector. I would expect roc_curve_survival_vec() to be called only after the dataframe method has already processed the dataframe into vector form, but that does not seem to be the case. I did not investigate brier_survival_vec() and brier_survival_integrated_vec() in as much detail, but I suspect that the issue might be similar for those functions.

Perhaps whoever wrote these metric functions wrote the dataframe versions from beginning to end and then added the vector versions without thoroughly testing them. If so, the bug might be resolved by rewriting the vector versions first and then writing the dataframe versions afterwards. I don't expect that any truly new code would be needed, but it seems that a lot of existing code would need to be rearranged.

I hope that the information I've given here can help to quickly identify and resolve these bugs.

EmilHvitfeldt commented 9 months ago

Hello @Tripartio 👋 welcome!

I think the issue you are running into is happening because of a typo. When using roc_auc_survival() you are using .pred, but then when you try roc_auc_survival_vec() you switch to .pred_time, which is causing your error. The required columns should stay the same when you switch from the metric to the _vec() function.

library(yardstick)

lung_surv |> 
  roc_auc_survival(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 5 × 4
#>   .metric          .estimator .eval_time .estimate
#>   <chr>            <chr>           <dbl>     <dbl>
#> 1 roc_auc_survival standard          100     0.659
#> 2 roc_auc_survival standard          200     0.679
#> 3 roc_auc_survival standard          300     0.688
#> 4 roc_auc_survival standard          400     0.648
#> 5 roc_auc_survival standard          500     0.662

roc_auc_survival_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred
)
#> # A tibble: 5 × 2
#>   .eval_time .estimate
#>        <dbl>     <dbl>
#> 1        100     0.659
#> 2        200     0.679
#> 3        300     0.688
#> 4        400     0.648
#> 5        500     0.662

lung_surv |> 
  brier_survival(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 5 × 4
#>   .metric        .estimator .eval_time .estimate
#>   <chr>          <chr>           <dbl>     <dbl>
#> 1 brier_survival standard          100     0.109
#> 2 brier_survival standard          200     0.194
#> 3 brier_survival standard          300     0.219
#> 4 brier_survival standard          400     0.222
#> 5 brier_survival standard          500     0.197

brier_survival_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred
)
#> # A tibble: 5 × 2
#>   .eval_time .estimate
#>        <dbl>     <dbl>
#> 1        100     0.109
#> 2        200     0.194
#> 3        300     0.219
#> 4        400     0.222
#> 5        500     0.197

lung_surv |> 
  brier_survival_integrated(
    truth = surv_obj,
    .pred
  )
#> # A tibble: 1 × 3
#>   .metric                   .estimator .estimate
#>   <chr>                     <chr>          <dbl>
#> 1 brier_survival_integrated standard       0.158

brier_survival_integrated_vec(
  truth = lung_surv$surv_obj,
  estimate = lung_surv$.pred
)
#> [1] 0.1576877

tripartio commented 9 months ago

@EmilHvitfeldt, thanks for the response, but what you gave me does not correspond at all to my understanding of what I expect from the *_vec functions (not just for survival, but throughout yardstick). It seems that my deliberate "typo" is because I expected the function to work very differently from the way it does, in the absence of up-to-date documentation. But I think that my expectation might be more reasonable than the actual behaviour of the function.

Here is the documentation for the estimate argument for roc_auc_survival_vec():

roc_auc_survival(data, truth, ..., na_rm = TRUE, case_weights = NULL)

roc_auc_survival_vec(truth, estimate, na_rm = TRUE, case_weights = NULL, ...)

Arguments
... 
estimate      If truth is binary, a numeric vector of class probabilities corresponding to the "relevant" class. Otherwise, a matrix with as many columns as factor levels of truth. It is assumed that these are in the same order as the levels of truth.

First, the documentation for estimate seems to be copy-pasted from roc_auc_vec(). It does not make much sense in the context of survival metrics, so I had to guess what "estimate" really means.

Second, what I expect for a *_vec function is that both the truth and the estimate would be vectors, or as close to vector as possible. For survival analysis, it is normal that truth would be a Surv() survival object. But why would estimate be a list of a special structure? That's what roc_auc_vec() essentially does already. Why would roc_auc_survival_vec() duplicate that functionality by doing little more than switching arguments around?

How do we even obtain that particular kind of prediction? It seems to me that the list structure is the very specific predict() output from {parsnip}/{censored}. Even though they are fellow members of {tidymodels}, these are distinct packages from {yardstick}. Such total dependency on other {tidymodels} packages renders these otherwise very useful metric functions completely unusable as a general metrics package independent of {tidymodels}.

If such total dependency is by design, I don't think that is a wise design decision. For example, I often use {yardstick} to calculate AUC when not using {tidymodels}. yardstick::roc_auc() is completely unusable apart from {tidymodels}, but yardstick::roc_auc_vec() works just fine because its inputs are simply two vectors, as its name implies. It makes sense for the primary function to work smoothly with and depend on {tidymodels} whereas the *_vec version is generic for any kind of use. I would expect roc_auc_survival_vec() to work the same way, but from your response, it seems that it might not.

Could you please clarify this by providing accurate documentation for the truth and estimate arguments of roc_auc_vec() and roc_auc_survival_vec()? (Perhaps if you give me a first draft here, it could eventually make it into the official documentation for the function.) That would help me understand the intention of the function.

Thanks for your time and consideration.

EmilHvitfeldt commented 9 months ago

> Here is the documentation for the estimate argument for roc_auc_survival_vec():

Thank you, that is a mistake and should be fixed. As a general rule in {yardstick}, all metrics that work like so:

metric_function(data, truth = col1, estimate = col2)

can be rewritten as

metric_function_vec(truth = data$col1, estimate = data$col2)

with the difference being that metric_function() returns a tibble, while metric_function_vec() returns the smallest possible output, usually a single number. All metrics use their metric_function_vec() counterpart internally to compute what they need. This way, if you already have the data in a tibble you can use metric_function(), but you still have access to the workhorse metric_function_vec() function if you so please.
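To make the rule concrete, here is a base-R sketch of the same pattern using a hypothetical metric. my_rmse() and my_rmse_vec() are names made up for illustration and are not yardstick code: the data-frame version only resolves the column names and delegates the actual computation to the _vec version.

```r
# Hypothetical metric mimicking the data-frame / _vec convention (not yardstick).

# Vector "workhorse": two plain numeric vectors in, a single number out.
my_rmse_vec <- function(truth, estimate, na_rm = TRUE) {
  stopifnot(is.numeric(truth), is.numeric(estimate),
            length(truth) == length(estimate))
  keep <- if (na_rm) stats::complete.cases(truth, estimate) else rep(TRUE, length(truth))
  sqrt(mean((truth[keep] - estimate[keep])^2))
}

# Data-frame wrapper: looks up the unquoted column names, calls the _vec
# function, and wraps the result in a one-row data frame.
my_rmse <- function(data, truth, estimate, na_rm = TRUE) {
  truth_col <- deparse(substitute(truth))
  estimate_col <- deparse(substitute(estimate))
  data.frame(
    .metric = "my_rmse",
    .estimator = "standard",
    .estimate = my_rmse_vec(data[[truth_col]], data[[estimate_col]], na_rm = na_rm)
  )
}

df <- data.frame(col1 = c(1, 2, 3, 4), col2 = c(1.5, 2, 2.5, 4))
my_rmse(df, truth = col1, estimate = col2)        # one-row data frame
my_rmse_vec(truth = df$col1, estimate = df$col2)  # the same number, unwrapped
#> [1] 0.3535534
```

Both calls compute the same value; only the container differs.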

There are a couple of caveats to this. They pertain to probability, survival, and curve metrics, which use ... instead of estimate to pass in the columns, because these metrics sometimes require multiple columns to be passed in:

metric_function(data, truth = col1, col2)
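Why ... is needed can be seen with a multiclass probability metric, where the estimate is spread over one column per class. The sketch below uses a hypothetical multiclass Brier score (my_brier() and my_brier_vec() are made-up names, not yardstick functions). The data-frame version collects the columns passed through ... and hands the _vec version a matrix. Column capture here uses base R substitute(); yardstick's own ... also accepts dplyr-style selectors, which this sketch does not.

```r
# Hypothetical multiclass Brier score (not yardstick) illustrating why
# probability-style metrics take their prediction columns through `...`.
my_brier_vec <- function(truth, estimate) {
  # truth: factor; estimate: matrix with one column per level of truth,
  # in the same order as levels(truth)
  stopifnot(is.factor(truth), is.matrix(estimate), ncol(estimate) == nlevels(truth))
  one_hot <- outer(as.integer(truth), seq_len(nlevels(truth)), `==`) * 1
  mean(rowSums((one_hot - estimate)^2))
}

my_brier <- function(data, truth, ...) {
  truth_col <- deparse(substitute(truth))
  # Capture the unevaluated column names passed through `...`
  prob_cols <- vapply(as.list(substitute(list(...)))[-1], deparse, character(1))
  estimate <- as.matrix(data[prob_cols])
  data.frame(.metric = "my_brier", .estimator = "multiclass",
             .estimate = my_brier_vec(data[[truth_col]], estimate))
}

preds <- data.frame(
  obs = factor(c("a", "b", "a"), levels = c("a", "b")),
  .pred_a = c(0.8, 0.3, 0.6),
  .pred_b = c(0.2, 0.7, 0.4)
)
my_brier(preds, truth = obs, .pred_a, .pred_b)
my_brier_vec(preds$obs, as.matrix(preds[c(".pred_a", ".pred_b")]))
```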

> Second, what I expect for a *_vec function is that both the truth and the estimate would be vectors, or as close to vector as possible. For survival analysis, it is normal that truth would be a Surv() survival object. But why would estimate be a list of a special structure? That's what roc_auc_vec() essentially does already. Why would roc_auc_survival_vec() duplicate that functionality by doing little more than switching arguments around?

Yes, the input structure is a little more complicated than with regression or classification metrics, but I hope that we have documented the input types clearly enough, like here: https://yardstick.tidymodels.org/dev/reference/brier_survival.html#details

If you want more information on why that specific input is needed, I refer you to the two articles we have written about it on tidymodels.org.

> How do we even obtain that particular kind of prediction? It seems to me that the list structure is the very specific predict() output from {parsnip}/{censored}. Even though they are fellow members of {tidymodels}, these are distinct packages from {yardstick}. Such total dependency on other {tidymodels} packages renders these otherwise very useful metric functions completely unusable as a general metrics package independent of {tidymodels}.

As far as we can tell, the inputs these metrics require are needed to calculate them correctly. You should be able to generate them without {tidymodels}, but it may require some wrangling to get the format that {yardstick} expects.
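As an illustration of that wrangling, the sketch below reshapes a plain matrix of predicted survival probabilities (from any model, one row per subject and one column per evaluation time) into the list-of-data-frames layout that the linked documentation describes. as_surv_pred_list() is a hypothetical helper, not part of yardstick. The weights used here are placeholders: all-ones weights are only appropriate without censoring, and real use needs inverse probability of censoring weights.

```r
# Hypothetical helper, not part of yardstick: turn an n x k matrix of
# predicted survival probabilities into a list with one data frame per
# subject, holding .eval_time, .pred_survival, and .weight_censored.
as_surv_pred_list <- function(surv_prob, eval_time, weight_censored) {
  stopifnot(
    is.matrix(surv_prob),
    ncol(surv_prob) == length(eval_time),
    identical(dim(weight_censored), dim(surv_prob))
  )
  lapply(seq_len(nrow(surv_prob)), function(i) {
    data.frame(
      .eval_time = eval_time,
      .pred_survival = surv_prob[i, ],
      .weight_censored = weight_censored[i, ]
    )
  })
}

# Two subjects, three evaluation times.
surv_prob <- matrix(c(0.9, 0.7, 0.5,
                      0.8, 0.6, 0.3), nrow = 2, byrow = TRUE)
# PLACEHOLDER weights: all 1 is only valid when nothing is censored.
weights <- matrix(1, nrow = 2, ncol = 3)
pred <- as_surv_pred_list(surv_prob, eval_time = c(100, 200, 300), weights)

# Untested assumption: `pred` could then be passed as `estimate` to e.g.
# brier_survival_vec(truth = <Surv object>, estimate = pred).
```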

> yardstick::roc_auc() is completely unusable apart from {tidymodels}

I strongly disagree on this point. yardstick::roc_auc() and yardstick::roc_auc_vec() are equally usable. It is just a matter of whether you have the data in a data.frame or not.

I'm not going to comment much on the rest of the post, because I think I have covered most of it earlier.

> Thanks for your time and consideration.

That is why I'm here!

EmilHvitfeldt commented 9 months ago

Also, I hope this doesn't come off as condescending, as that isn't the intent 🤞

The ROC curve for censored data is calculated differently from the uncensored case. See the references listed here for how censoring changes the calculations:

https://yardstick.tidymodels.org/dev/reference/roc_curve_survival.html#references
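To give a flavour of how censoring changes the calculation, here is a minimal sketch of one common formulation of a time-dependent (cumulative/dynamic) AUC at a single time t. It is not yardstick's implementation, and dynamic_auc() is a hypothetical name. It assumes the caller already has inverse probability of censoring weights for the cases (1 / G(T_i-), where G is the censoring survival function). Subjects censored before t are neither cases nor controls, and each case is weighted to compensate for them.

```r
# Sketch of a cumulative/dynamic AUC at time t (not yardstick's code).
# surv_prob: predicted survival probability at t for each subject.
# case_weight: assumed precomputed IPCW weights 1 / G(T_i-), one per subject.
dynamic_auc <- function(time, status, surv_prob, t, case_weight) {
  # Cases: observed event at or before t
  is_case <- time <= t & status == 1
  # Controls: still event-free after t. Their weight 1 / G(t) is the same
  # for every control, so it cancels out of the ratio below.
  is_control <- time > t
  # Anyone censored at or before t is in neither group.
  stopifnot(any(is_case), any(is_control))
  s_case <- surv_prob[is_case]
  s_ctrl <- surv_prob[is_control]
  w <- case_weight[is_case]
  # Concordant pair: the case has the lower predicted survival at t (ties = 1/2)
  conc <- outer(s_case, s_ctrl, "<") + 0.5 * outer(s_case, s_ctrl, "==")
  # w recycles down the rows, so row i of conc is weighted by w[i]
  sum(w * conc) / (sum(w) * length(s_ctrl))
}

# With no censoring all weights are 1 and this reduces to an ordinary AUC.
time <- c(50, 150, 250, 350)
status <- c(1, 1, 1, 1)
surv_at_200 <- c(0.2, 0.6, 0.5, 0.9)
dynamic_auc(time, status, surv_at_200, t = 200, case_weight = rep(1, 4))
#> [1] 0.75
```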

tripartio commented 9 months ago

@EmilHvitfeldt Thank you for your detailed response. No, it does not come off as condescending; rather, I really appreciate your time.

I will need some time to absorb all this and retest what I'm trying to do. I hope to follow up next week.