tidymodels / censored

Parsnip wrappers for survival models
https://censored.tidymodels.org/

`multi_predict(type = "survival")` is missing `penalty` column for single penalty value #267

Closed: hfrick closed this issue 10 months ago

hfrick commented 1 year ago
library(censored)
#> Loading required package: parsnip
#> Loading required package: survival
lung2 <- lung[-14, ]

set.seed(14)
f_fit <- proportional_hazards(penalty = 0.123) %>%
  set_mode("censored regression") %>%
  set_engine("glmnet") %>%
  fit(Surv(time, status) ~ age + ph.ecog, data = lung2)

pred_multi <- multi_predict(f_fit, new_data = lung2[1:3, ], type = "survival",
  eval_time = c(100, 200), penalty = 0.1)

# this should have a penalty column
pred_multi$.pred[[1]]
#> # A tibble: 2 × 2
#>   .eval_time .pred_survival
#>        <dbl>          <dbl>
#> 1        100          0.868
#> 2        200          0.680

Created on 2023-05-04 with reprex v2.0.2

hfrick commented 11 months ago

This currently doesn't work because survival_prob_coxnet() infers `multi` from the length of `penalty`, so `multi` is only TRUE when more than one penalty value is supplied. That makes it impossible to request `multi = TRUE` with a single penalty value.

The plan is to give survival_prob_coxnet() its own `multi` argument and call it directly from multi_predict(). We can't route through predict() as we do now, because predict() won't pass a `multi` argument through. multi_predict() will therefore need to carry out any checks that predict() currently performs, such as validating `type`.
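To make the difference concrete, here is a minimal base-R sketch contrasting the two approaches. The helpers `format_surv()`, `surv_prob_inferred()` and `surv_prob_explicit()` are hypothetical and are not censored's internals; they only illustrate why inferring `multi` from `length(penalty)` drops the `penalty` column for a single value, while an explicit `multi` argument keeps it. The probabilities are placeholders copied from the reprex output above.

```r
# Hypothetical sketch: NOT censored's actual code.
# prob is a matrix with one row per eval_time and one column per penalty value.
format_surv <- function(eval_time, prob, penalty, multi) {
  if (!multi) {
    # predict()-style output: no penalty column
    return(data.frame(.eval_time = eval_time, .pred_survival = prob[, 1]))
  }
  # multi_predict()-style output: one row per (penalty, eval_time) pair
  data.frame(
    penalty = rep(penalty, each = length(eval_time)),
    .eval_time = rep(eval_time, times = length(penalty)),
    .pred_survival = as.vector(prob)  # column-major: all eval_times per penalty
  )
}

# Current behaviour as described in the issue: multi is inferred from penalty length
surv_prob_inferred <- function(eval_time, prob, penalty) {
  format_surv(eval_time, prob, penalty, multi = length(penalty) > 1)
}

# Planned behaviour: the caller (e.g. multi_predict()) states multi explicitly
surv_prob_explicit <- function(eval_time, prob, penalty, multi = FALSE) {
  format_surv(eval_time, prob, penalty, multi = multi)
}

prob <- matrix(c(0.868, 0.680), ncol = 1)  # 2 eval times x 1 penalty (placeholder values)
old <- surv_prob_inferred(c(100, 200), prob, penalty = 0.1)
new <- surv_prob_explicit(c(100, 200), prob, penalty = 0.1, multi = TRUE)
print(names(old))
print(names(new))
```

With a single penalty, the inferred version returns only `.eval_time` and `.pred_survival`, which matches the reported bug. The explicit version also includes `penalty`.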

github-actions[bot] commented 10 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.