chjackson / flexsurv

The flexsurv R package for flexible parametric survival and multi-state modelling
http://chjackson.github.io/flexsurv/

`predict.flexsurvreg()` doesn't nest for a single time point #191

Open · hfrick opened this issue 2 months ago

hfrick commented 2 months ago

I noticed that flexsurv's predict method does not nest the survival probabilities when there is only one time point at which to calculate them.

While nesting isn't necessary in that case to get a data frame with one row per observation, it does make the output format consistent. (I've added an example with censored, which does nest in that case.)

Is that behaviour something you want to preserve, or could I send a PR to change it? I think it might only be a matter of changing

https://github.com/chjackson/flexsurv/blob/d5537c7b2e59d6224f844098d5b42862b948aa7a/R/predict.flexsurvreg.R#L201

but I haven't been able to test that yet. One of the tests in the corresponding test file segfaults for me locally with the current dev version, and I haven't figured out why.

Here's a reprex to illustrate the behaviour:

library(flexsurv)
#> Loading required package: survival
library(censored)
#> Loading required package: parsnip

flexsurv_fit <- flexsurv::flexsurvreg(
  Surv(time, status) ~ age + sex,
  data = lung,
  dist = "weibull"
)

# for a single time point, the output is unnested
predict(flexsurv_fit, lung[2,], type = "survival", time = 100)
#> # A tibble: 1 × 2
#>   .eval_time .pred_survival
#>        <dbl>          <dbl>
#> 1        100          0.820
predict(flexsurv_fit, lung, type = "survival", time = 100)
#> # A tibble: 228 × 2
#>    .eval_time .pred_survival
#>         <dbl>          <dbl>
#>  1        100          0.803
#>  2        100          0.820
#>  3        100          0.849
#>  4        100          0.847
#>  5        100          0.840
#>  6        100          0.803
#>  7        100          0.887
#>  8        100          0.882
#>  9        100          0.856
#> 10        100          0.837
#> # ℹ 218 more rows

# with more than one time point, the output is nested
predict(flexsurv_fit, lung[2,], type = "survival", time = c(100, 200))
#> # A tibble: 1 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [2 × 2]>
predict(flexsurv_fit, lung, type = "survival", time = c(100, 200))
#> # A tibble: 228 × 1
#>    .pred           
#>    <list>          
#>  1 <tibble [2 × 2]>
#>  2 <tibble [2 × 2]>
#>  3 <tibble [2 × 2]>
#>  4 <tibble [2 × 2]>
#>  5 <tibble [2 × 2]>
#>  6 <tibble [2 × 2]>
#>  7 <tibble [2 × 2]>
#>  8 <tibble [2 × 2]>
#>  9 <tibble [2 × 2]>
#> 10 <tibble [2 × 2]>
#> # ℹ 218 more rows

censored_fit <- survival_reg() %>%
  set_engine("flexsurv") %>%
  fit(Surv(time, status) ~ age + sex, data = lung)

# same format regardless of the number of time points (or number of observations)
predict(censored_fit, lung[2,], type = "survival", eval_time = 200)
#> # A tibble: 1 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [1 × 2]>
predict(censored_fit, lung, type = "survival", eval_time = 200)
#> # A tibble: 228 × 1
#>    .pred           
#>    <list>          
#>  1 <tibble [1 × 2]>
#>  2 <tibble [1 × 2]>
#>  3 <tibble [1 × 2]>
#>  4 <tibble [1 × 2]>
#>  5 <tibble [1 × 2]>
#>  6 <tibble [1 × 2]>
#>  7 <tibble [1 × 2]>
#>  8 <tibble [1 × 2]>
#>  9 <tibble [1 × 2]>
#> 10 <tibble [1 × 2]>
#> # ℹ 218 more rows

predict(censored_fit, lung[2,], type = "survival", eval_time = c(100, 200))
#> # A tibble: 1 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [2 × 2]>
predict(censored_fit, lung, type = "survival", eval_time = c(100, 200))
#> # A tibble: 228 × 1
#>    .pred           
#>    <list>          
#>  1 <tibble [2 × 2]>
#>  2 <tibble [2 × 2]>
#>  3 <tibble [2 × 2]>
#>  4 <tibble [2 × 2]>
#>  5 <tibble [2 × 2]>
#>  6 <tibble [2 × 2]>
#>  7 <tibble [2 × 2]>
#>  8 <tibble [2 × 2]>
#>  9 <tibble [2 × 2]>
#> 10 <tibble [2 × 2]>
#> # ℹ 218 more rows

Created on 2024-04-19 with reprex v2.1.0
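As a stop-gap on the user side, the consistent layout can be recovered by post-processing. The sketch below assumes the unnested single-time-point output has one row per observation, in input order, with the `.eval_time` and `.pred_survival` columns shown in the reprex above. `nest_single_time()` is a hypothetical helper, not part of flexsurv or censored, and it is not the proposed change to predict.flexsurvreg.R; it only illustrates the target format.

```r
library(tibble)

# Hypothetical helper (not part of flexsurv or censored): take the unnested
# single-time-point output of predict.flexsurvreg() and wrap each row in a
# one-row tibble inside a `.pred` list-column, matching the censored layout.
nest_single_time <- function(pred) {
  # Output that is already nested (multiple time points) is returned unchanged
  if (".pred" %in% names(pred)) {
    return(pred)
  }
  # Assumes row i of `pred` belongs to observation i
  tibble(.pred = lapply(seq_len(nrow(pred)), function(i) pred[i, , drop = FALSE]))
}

# Illustrative input using the columns and first two values from the reprex
pred <- tibble(.eval_time = c(100, 100), .pred_survival = c(0.803, 0.820))
nest_single_time(pred)
```

Calling it on output that is already nested returns that output unchanged, so it could be applied regardless of how many time points were requested.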

chjackson commented 2 months ago

@mattwarkentin would you be able to look at this, as the original author? I have no idea about the pros and cons of this.

mattwarkentin commented 2 months ago

Absolutely! Happy to take a look. Will report back.

mattwarkentin commented 2 months ago

I don't have a strong opinion on nesting versus not nesting when there is a single prediction per observation. I do agree with @hfrick that format consistency is a nice feature, and I think reverting to always nesting results into list-columns of tibbles would be okay. For time-to-event models like these, it will generally be more common to make predictions at multiple horizons, so getting a different output format when predicting at only a single time point may cause friction. I think that's an argument for always returning the same format. What do you think?

For context, I think I chose to return single-time-point predictions un-nested because it saved the user the extra step of unnesting a single observation. Also, predict() methods for many standard regression models (outside the tidymodels ecosystem) return vectors, and this felt like a comparable approach.
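The "extra step of unnesting" is a single tidyr call. A minimal sketch, assuming the tibble and tidyr packages are installed; the nested tibble is built by hand in the censored layout, using the row 1 and row 2 values from the reprex (0.803 and 0.820 at time 100), rather than produced by predict().

```r
library(tibble)
library(tidyr)

# Nested layout as returned by censored (and by flexsurv for several time points):
# one row per observation, each holding a tibble of predictions.
# Values are taken from rows 1 and 2 of the single-time-point reprex output.
nested <- tibble(.pred = list(
  tibble(.eval_time = 100, .pred_survival = 0.803),
  tibble(.eval_time = 100, .pred_survival = 0.820)
))

# The extra step: unnest the list-column back to one row per observation and time
flat <- unnest(nested, cols = .pred)
flat
```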

But I could easily be convinced to go with a consistent output format over forced convenience. Moreover, we have already gone to some lengths to make the predict.flexsurvreg() method "tidy" so that it plays nicely with censored, so this feels like a reasonable change to make the two packages essentially fully consistent.