tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org

update predict.model_fit(type = "quantile") #1203

Open topepo opened 2 months ago

topepo commented 2 months ago

We are adding a mode for quantile regression but have one engine that already enables such prediction (using the censored regression mode).

We should allow that but make some adjustments to harmonize both approaches.

topepo commented 1 month ago

Some notes...

The problem is that we have a predict() method that takes type = "quantile".

With the new quantile regression mode, we specify the quantile levels with set_mode(). The current predict() method has a quantile argument, which is problematic.
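As a rough sketch of the new workflow (assuming parsnip 1.3.0 or later and the quantreg package are installed; `mtcars` and the formula are stand-ins chosen only for illustration), the quantile levels travel with the model specification rather than with the predict() call:

```r
library(parsnip)

# Quantile levels are attached to the specification via set_mode(),
# not passed to predict().
quantile_spec <-
  linear_reg() %>%
  set_engine("quantreg") %>%
  set_mode("quantile regression", quantile_levels = c(0.3, 0.5, 0.7))

# mtcars is only a stand-in data set for illustration.
quantile_fit <- fit(quantile_spec, mpg ~ wt + hp, data = mtcars)

# No `quantile` argument here: the levels come from the specification.
quantile_preds <- predict(quantile_fit, new_data = mtcars[1:3, ])
quantile_preds
```

This is why a separate quantile argument on predict() becomes awkward: the same information could then be supplied in two places.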

A few models already have quantile prediction methods. Two survival engines for parametric models (flexsurv and survival) have them. The bayesian package also offers this prediction type for both regression and classification models.

Proposed changes:

topepo commented 1 month ago

Regarding the bayesian package... it will be a breaking change. However, the package doesn't really follow any of our guidelines for naming arguments/prediction columns and using tidy data formats.

library(tidymodels)
library(bayesian)
#> Loading required package: brms
#> Loading required package: Rcpp
#> 
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#> 
#>     populate
#> Loading 'brms' package (version 2.21.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:dials':
#> 
#>     mixture
#> The following object is masked from 'package:stats':
#> 
#>     ar
# regression example

bayesian_fit <-
  bayesian() %>%
  set_mode("regression") %>%
  set_engine("brms") %>%
  fit(
    rating ~ treat + period + carry + (1 | subject),
    data = inhaler
  )
#> Compiling Stan program...
#> Trying to compile a simple C file
#> Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
#> using C compiler: ‘Apple clang version 15.0.0 (clang-1500.3.9.4)’
#> using SDK: ‘’
#> clang -arch arm64 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG   -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/Rcpp/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/unsupported"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/src/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppParallel/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/rstan/include" -DEIGEN_NO_DEBUG  -DBOOST_DISABLE_ASSERTS  -DBOOST_PENDING_INTEGER_LOG2_HPP  -DSTAN_THREADS  -DUSE_STANC3 -DSTRICT_R_HEADERS  -DBOOST_PHOENIX_NO_VARIADIC_EXPRESSION  -D_HAS_AUTO_PTR_ETC=0  -include '/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp'  -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1   -I/opt/R/arm64/include    -fPIC  -falign-functions=64 -Wall -g -O2  -c foo.c -o foo.o
#> In file included from <built-in>:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp:22:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Dense:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Core:19:
#> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:679:10: fatal error: 'cmath' file not found
#> #include <cmath>
#>          ^~~~~~~
#> 1 error generated.
#> make: *** [foo.o] Error 1
#> Start sampling
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 7.7e-05 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.77 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 0.748 seconds (Warm-up)
#> Chain 1:                0.356 seconds (Sampling)
#> Chain 1:                1.104 seconds (Total)
#> Chain 1: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 2).
#> Chain 2: 
#> Chain 2: Gradient evaluation took 2.9e-05 seconds
#> Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.29 seconds.
#> Chain 2: Adjust your expectations accordingly!
#> Chain 2: 
#> Chain 2: 
#> Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 2: 
#> Chain 2:  Elapsed Time: 0.726 seconds (Warm-up)
#> Chain 2:                0.355 seconds (Sampling)
#> Chain 2:                1.081 seconds (Total)
#> Chain 2: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 3).
#> Chain 3: 
#> Chain 3: Gradient evaluation took 2.6e-05 seconds
#> Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.26 seconds.
#> Chain 3: Adjust your expectations accordingly!
#> Chain 3: 
#> Chain 3: 
#> Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 3: 
#> Chain 3:  Elapsed Time: 0.7 seconds (Warm-up)
#> Chain 3:                0.355 seconds (Sampling)
#> Chain 3:                1.055 seconds (Total)
#> Chain 3: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 4).
#> Chain 4: 
#> Chain 4: Gradient evaluation took 2.3e-05 seconds
#> Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.23 seconds.
#> Chain 4: Adjust your expectations accordingly!
#> Chain 4: 
#> Chain 4: 
#> Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 4: 
#> Chain 4:  Elapsed Time: 0.734 seconds (Warm-up)
#> Chain 4:                0.355 seconds (Sampling)
#> Chain 4:                1.089 seconds (Total)
#> Chain 4:

# Results are not in a tidy format and do not follow the tidymodels rules for
# naming prediction columns.
predict(bayesian_fit, inhaler, type = "quantile", quantile = c(.3, .5, .7))
#> Warning in c(0.3, 0.5, 0.7): For regression models, making quantile prediction requires a model with a
#> "quantile regression" mode as of parsnip version 1.3.0.
#> # A tibble: 572 × 5
#>    Estimate Est.Error   Q30   Q50   Q70
#>       <dbl>     <dbl> <dbl> <dbl> <dbl>
#>  1     1.21     0.581 0.911  1.22  1.51
#>  2     1.19     0.591 0.877  1.18  1.50
#>  3     1.21     0.592 0.907  1.20  1.52
#>  4     1.19     0.612 0.859  1.18  1.52
#>  5     1.20     0.613 0.886  1.19  1.50
#>  6     1.19     0.593 0.886  1.17  1.48
#>  7     1.22     0.595 0.910  1.21  1.51
#>  8     1.20     0.599 0.901  1.21  1.52
#>  9     1.19     0.595 0.875  1.18  1.51
#> 10     1.18     0.594 0.877  1.18  1.50
#> # ℹ 562 more rows
# Classification example

# data from: https://stats.oarc.ucla.edu/r/dae/mixed-effects-logistic-regression/
hdp <- 
  read.csv("https://stats.idre.ucla.edu/stat/data/hdp.csv") %>% 
  mutate(
    Married = factor(Married, levels = 0:1, labels = c("no", "yes")),
    DID = factor(DID),
    HID = factor(HID),
    CancerStage = factor(CancerStage),
    remission = factor(ifelse(remission == 1, "yes", "no"))
  )

bayesian_fit <-
  bayesian(family = bernoulli(link = "logit")) %>%
  set_mode("classification") %>%
  set_engine("brms") %>%
  fit(remission ~ IL6 + CRP + (1 | DID), data = hdp)
#> Compiling Stan program...
#> Trying to compile a simple C file
#> Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
#> using C compiler: ‘Apple clang version 15.0.0 (clang-1500.3.9.4)’
#> using SDK: ‘’
#> clang -arch arm64 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG   -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/Rcpp/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/unsupported"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/src/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppParallel/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/rstan/include" -DEIGEN_NO_DEBUG  -DBOOST_DISABLE_ASSERTS  -DBOOST_PENDING_INTEGER_LOG2_HPP  -DSTAN_THREADS  -DUSE_STANC3 -DSTRICT_R_HEADERS  -DBOOST_PHOENIX_NO_VARIADIC_EXPRESSION  -D_HAS_AUTO_PTR_ETC=0  -include '/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp'  -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1   -I/opt/R/arm64/include    -fPIC  -falign-functions=64 -Wall -g -O2  -c foo.c -o foo.o
#> In file included from <built-in>:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp:22:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Dense:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Core:19:
#> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:679:10: fatal error: 'cmath' file not found
#> #include <cmath>
#>          ^~~~~~~
#> 1 error generated.
#> make: *** [foo.o] Error 1
#> Start sampling
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 0.000487 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 4.87 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 14.297 seconds (Warm-up)
#> Chain 1:                4.489 seconds (Sampling)
#> Chain 1:                18.786 seconds (Total)
#> Chain 1: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 2).
#> Chain 2: 
#> Chain 2: Gradient evaluation took 0.000297 seconds
#> Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 2.97 seconds.
#> Chain 2: Adjust your expectations accordingly!
#> Chain 2: 
#> Chain 2: 
#> Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 2: 
#> Chain 2:  Elapsed Time: 12.056 seconds (Warm-up)
#> Chain 2:                4.56 seconds (Sampling)
#> Chain 2:                16.616 seconds (Total)
#> Chain 2: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 3).
#> Chain 3: 
#> Chain 3: Gradient evaluation took 0.000302 seconds
#> Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 3.02 seconds.
#> Chain 3: Adjust your expectations accordingly!
#> Chain 3: 
#> Chain 3: 
#> Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 3: 
#> Chain 3:  Elapsed Time: 12.867 seconds (Warm-up)
#> Chain 3:                4.546 seconds (Sampling)
#> Chain 3:                17.413 seconds (Total)
#> Chain 3: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 4).
#> Chain 4: 
#> Chain 4: Gradient evaluation took 0.000295 seconds
#> Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 2.95 seconds.
#> Chain 4: Adjust your expectations accordingly!
#> Chain 4: 
#> Chain 4: 
#> Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 4: 
#> Chain 4:  Elapsed Time: 12.375 seconds (Warm-up)
#> Chain 4:                4.507 seconds (Sampling)
#> Chain 4:                16.882 seconds (Total)
#> Chain 4:
#> Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#bulk-ess

# This doesn't seem to work:
predict(bayesian_fit, hdp %>% select(-remission), type = "quantile", 
        quantile = c(.3, .5, .7))
#> # A tibble: 8,525 × 5
#>    Estimate Est.Error   Q30   Q50   Q70
#>       <dbl>     <dbl> <dbl> <dbl> <dbl>
#>  1   0.0192     0.137     0     0     0
#>  2   0.0335     0.180     0     0     0
#>  3   0.0158     0.125     0     0     0
#>  4   0.03       0.171     0     0     0
#>  5   0.026      0.159     0     0     0
#>  6   0.031      0.173     0     0     0
#>  7   0.0265     0.161     0     0     0
#>  8   0.0245     0.155     0     0     0
#>  9   0.0215     0.145     0     0     0
#> 10   0.0215     0.145     0     0     0
#> # ℹ 8,515 more rows

Created on 2024-09-16 with reprex v2.1.1
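For comparison, here is a minimal sketch of how the wide `Q30`/`Q50`/`Q70` output from the regression example above could be reshaped into one row per observation and quantile level. The first three rows are copied from the printed output; the `.row` column and the choice of `.quantile` and `.pred_quantile` as names are assumptions for illustration, not what the bayesian package or parsnip will necessarily return.

```r
library(dplyr)
library(tidyr)
library(tibble)

# First three rows of the printed regression predictions above.
wide <- tribble(
  ~Estimate, ~Est.Error,  ~Q30, ~Q50, ~Q70,
       1.21,      0.581, 0.911, 1.22, 1.51,
       1.19,      0.591, 0.877, 1.18, 1.50,
       1.21,      0.592, 0.907, 1.20, 1.52
)

long <-
  wide %>%
  mutate(.row = row_number()) %>%
  # "Q30" -> "30": strip the prefix and keep the level as a string for now.
  pivot_longer(
    c(Q30, Q50, Q70),
    names_to = ".quantile",
    names_prefix = "Q",
    values_to = ".pred_quantile"
  ) %>%
  # Convert the percent label to a probability, e.g. "30" -> 0.3.
  mutate(.quantile = as.numeric(.quantile) / 100) %>%
  select(.row, .quantile, .pred_quantile)

long
```

In the long layout the quantile level is data rather than part of a column name, so the same code works for any set of levels.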

hfrick commented 1 month ago

Currently we do not produce any interval estimate for quantile predictions; we will remove the interval and level arguments to predict_quantile().

That's incorrect: we do produce them for the flexsurv and flexsurvspline engines in censored for survival_reg() models. I'd like us to keep that functionality.

library(censored)
#> Loading required package: parsnip
#> Loading required package: survival

# flexsurv engine
set.seed(1)
fit_s <- survival_reg() %>%
  set_engine("flexsurv") %>%
  set_mode("censored regression") %>%
  fit(Surv(stop, event) ~ rx + size + enum, data = bladder)

pred <- predict(fit_s,
  new_data = bladder[1:3, ], type = "quantile",
  interval = "confidence", level = 0.7
)
pred
#> # A tibble: 3 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [9 × 4]>
#> 2 <tibble [9 × 4]>
#> 3 <tibble [9 × 4]>
pred$.pred[[1]]
#> # A tibble: 9 × 4
#>   .quantile .pred_quantile .pred_lower .pred_upper
#>       <dbl>          <dbl>       <dbl>       <dbl>
#> 1       0.1           3.57        2.75        4.46
#> 2       0.2           7.33        5.83        8.86
#> 3       0.3          11.5         9.33       13.7 
#> 4       0.4          16.2        13.3        19.4 
#> 5       0.5          21.7        18.0        25.9 
#> 6       0.6          28.3        23.6        33.8 
#> 7       0.7          36.8        30.8        44.1 
#> 8       0.8          48.5        40.5        58.5 
#> 9       0.9          68.4        56.4        83.6

# flexsurvspline engine
set.seed(1)
fit_s <- survival_reg() %>%
  set_engine("flexsurvspline", k = 1) %>%
  set_mode("censored regression") %>%
  fit(Surv(stop, event) ~ rx + size + enum, data = bladder)

pred <- predict(fit_s,
  new_data = bladder[1:3, ], type = "quantile",
  interval = "confidence", level = 0.7
)
pred
#> # A tibble: 3 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [9 × 4]>
#> 2 <tibble [9 × 4]>
#> 3 <tibble [9 × 4]>
pred$.pred[[1]]
#> # A tibble: 9 × 4
#>   .quantile .pred_quantile .pred_lower .pred_upper
#>       <dbl>          <dbl>       <dbl>       <dbl>
#> 1       0.1           3.86        3.08        4.70
#> 2       0.2           7.17        5.90        8.67
#> 3       0.3          10.8         8.94       13.1 
#> 4       0.4          15.2        12.6        18.3 
#> 5       0.5          20.6        17.2        24.8 
#> 6       0.6          27.6        23.0        33.5 
#> 7       0.7          37.1        31.1        45.2 
#> 8       0.8          51.2        42.4        64.3 
#> 9       0.9          76.2        61.4       100.

Created on 2024-09-17 with reprex v2.1.0

hfrick commented 1 month ago

We will want to transition the quantile prediction type to be specifically reserved for regression models built to predict quantiles. Some ordinary regression models can compute quantiles of the prediction distribution, but these are not optimized for accuracy. We can grandfather the existing censored regression models as-is since a model cannot have both a censored regression mode and a quantile regression mode.

Why do we want to reserve type = "quantile" for models with mode = "quantile regression"? Wouldn't the mode be enough distinction? We can document that only the quantiles predicted by quantile regression models are optimized for accuracy but still allow other types of quantiles.

topepo commented 1 month ago

For survival::survreg() objects, the quantile levels do not appear to be stored anywhere in the output of predict(). We may need a wrapper to add an attribute or to pre-format it into a tidy format. I've exported parsnip::matrix_to_quantile_pred() since the output is similar to that produced by quantreg.
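To illustrate the point, here is a minimal wrapper sketch using only base R and survival. `survreg_quantiles()` is a hypothetical helper, not a parsnip function, and it does not use parsnip::matrix_to_quantile_pred(); the `lung` data and the output column names are assumptions for illustration. It records the requested levels next to the values, since the raw predict() result does not carry them.

```r
library(survival)

# Hypothetical helper: predict quantiles from a survreg fit and return them
# in long form together with the quantile levels that were requested.
survreg_quantiles <- function(fit, new_data, quantile_levels = (1:9) / 10) {
  res <- predict(fit, newdata = new_data, type = "quantile", p = quantile_levels)
  # predict() returns a bare vector for a single row or a single level,
  # so coerce to a rows-by-levels matrix in every case.
  res <- matrix(res, nrow = nrow(new_data), ncol = length(quantile_levels))
  data.frame(
    .row = rep(seq_len(nrow(new_data)), times = length(quantile_levels)),
    .quantile = rep(quantile_levels, each = nrow(new_data)),
    .pred_quantile = as.vector(res)  # column-major: all rows for level 1 first
  )
}

# lung ships with survival and is used here only as a stand-in data set.
fit_wb <- survreg(Surv(time, status) ~ age + sex, data = lung)
out <- survreg_quantiles(fit_wb, lung[1:3, ], quantile_levels = c(0.3, 0.5, 0.7))
out
```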

hfrick commented 1 month ago

Adding a todo as part of this: the docs for predict.model_fit() currently describe the return value for predict(type = "quantile") as a list column. This needs updating to the new vctrs class -- and, more user-facing, the new column name for this.