tidymodels / censored

Parsnip wrappers for survival models
https://censored.tidymodels.org/
Other
123 stars 12 forks source link

Fresh implementation aft and cox #260

Closed brunocarlin closed 4 months ago

brunocarlin commented 1 year ago

I have fixed the bug in the previous implementation; it was a `parse()` vs. rlang issue.

library(tidymodels)
#> Warning in system("timedatectl", intern = TRUE): running command 'timedatectl'
#> had status 1
library(censored)
#> Loading required package: survival
library(tidyverse)
library(survival)

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]

test_aft <-
  boost_tree()|> set_engine('xgboost') |> set_mode('censored regression')

test_aft |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     nthread = 1, verbose = 0, objective = "survival:aft")

set.seed(1)
bt_fit <- test_aft %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 36.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_aft_nloglik
#>        1            14.698676
#>        2             9.883017
#> ---                          
#>       14             4.564666
#>       15             4.553461

predict(
  bt_fit,
  lung_test,
  type = "linear_pred",
)
#> Error in `object$spec$method$pred$linear_pred$pre()`:
#> ! The objective should be survival:cox not survival:aft
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "linear_pred", )
#>  2. └─parsnip::predict.model_fit(...)
#>  3.   ├─parsnip::predict_linear_pred(...)
#>  4.   └─parsnip::predict_linear_pred.model_fit(...)
#>  5.     └─object$spec$method$pred$linear_pred$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:cox not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:280:8

predict(bt_fit,lung_test,type = 'time')
#> # A tibble: 5 × 1
#>   .pred_time
#>        <dbl>
#> 1      420. 
#> 2      239. 
#> 3      120. 
#> 4       78.7
#> 5      350.

test_cox <-
  boost_tree()|> set_engine('xgboost',objective = 'survival:cox')  |> set_mode('censored regression')

test_cox |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Engine-Specific Arguments:
#>   objective = survival:cox
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = "survival:cox", nthread = 1, verbose = 0)

set.seed(1)
bt_fit <- test_cox %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 40.5 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:cox"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:cox", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_cox_nloglik
#>        1             3.967019
#>        2             3.840237
#> ---                          
#>       14             3.095573
#>       15             3.054746

predict(bt_fit, lung_test, type = 'time')
#> Error in `object$spec$method$pred$time$pre()`:
#> ! The objective should be survival:aft not survival:cox
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "time")
#>  2. └─parsnip::predict.model_fit(bt_fit, lung_test, type = "time")
#>  3.   ├─parsnip::predict_time(object = object, new_data = new_data, ...)
#>  4.   └─parsnip::predict_time.model_fit(...)
#>  5.     └─object$spec$method$pred$time$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:aft not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:256:8

predict(bt_fit,
        lung_test,
        type = "linear_pred")
#> # A tibble: 5 × 1
#>   .pred_linear_pred
#>               <dbl>
#> 1             0.351
#> 2             4.41 
#> 3             2.23 
#> 4             4.50 
#> 5             2.36

Created on 2023-04-13 with reprex v2.0.2

Session info ``` r sessioninfo::session_info() #> ─ Session info ─────────────────────────────────────────────────────────────── #> setting value #> version R version 4.2.3 (2023-03-15) #> os Ubuntu 22.04.2 LTS #> system x86_64, linux-gnu #> ui X11 #> language (EN) #> collate C.UTF-8 #> ctype C.UTF-8 #> tz America/Sao_Paulo #> date 2023-04-13 #> pandoc 2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown) #> #> ─ Packages ─────────────────────────────────────────────────────────────────── #> package * version date (UTC) lib source #> backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.3) #> broom * 1.0.4 2023-03-11 [1] CRAN (R 4.2.3) #> censored * 0.1.1.9003 2023-04-14 [1] local #> class 7.3-21 2023-01-23 [4] CRAN (R 4.2.2) #> cli 3.6.1 2023-03-23 [1] CRAN (R 4.2.3) #> codetools 0.2-19 2023-02-01 [4] CRAN (R 4.2.2) #> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.2.3) #> data.table 1.14.8 2023-02-17 [1] CRAN (R 4.2.3) #> dials * 1.2.0 2023-04-03 [1] CRAN (R 4.2.3) #> DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.2.3) #> digest 0.6.31 2022-12-11 [1] CRAN (R 4.2.3) #> dplyr * 1.1.1 2023-03-22 [1] CRAN (R 4.2.3) #> evaluate 0.20 2023-01-17 [1] CRAN (R 4.2.3) #> fansi 1.0.4 2023-01-22 [1] CRAN (R 4.2.3) #> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.2.3) #> forcats * 1.0.0 2023-01-29 [1] CRAN (R 4.2.3) #> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.2.3) #> fs 1.6.1 2023-02-06 [1] CRAN (R 4.2.3) #> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.2.3) #> future 1.32.0 2023-03-07 [1] CRAN (R 4.2.3) #> future.apply 1.10.0 2022-11-05 [1] CRAN (R 4.2.3) #> generics 0.1.3 2022-07-05 [1] CRAN (R 4.2.3) #> ggplot2 * 3.4.2 2023-04-03 [1] CRAN (R 4.2.3) #> globals 0.16.2 2022-11-21 [1] CRAN (R 4.2.3) #> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.3) #> gower 1.0.1 2022-12-22 [1] CRAN (R 4.2.3) #> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.2.3) #> gtable 0.3.3 2023-03-21 [1] CRAN (R 4.2.3) #> hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.2.3) #> hms 1.1.3 2023-03-21 [1] CRAN (R 4.2.3) #> htmltools 0.5.5 2023-03-23 [1] 
CRAN (R 4.2.3) #> infer * 1.0.4 2022-12-02 [1] CRAN (R 4.2.3) #> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.2.3) #> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.2.3) #> jsonlite 1.8.4 2022-12-06 [1] CRAN (R 4.2.3) #> knitr 1.42 2023-01-25 [1] CRAN (R 4.2.3) #> lattice 0.20-45 2021-09-22 [4] CRAN (R 4.2.0) #> lava 1.7.2.1 2023-02-27 [1] CRAN (R 4.2.3) #> lhs 1.1.6 2022-12-17 [1] CRAN (R 4.2.3) #> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.2.3) #> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.2.3) #> lubridate * 1.9.2 2023-02-10 [1] CRAN (R 4.2.3) #> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.3) #> MASS 7.3-58.3 2023-03-07 [4] CRAN (R 4.2.3) #> Matrix 1.5-1 2022-09-13 [4] CRAN (R 4.2.1) #> modeldata * 1.1.0 2023-01-25 [1] CRAN (R 4.2.3) #> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.3) #> nnet 7.3-18 2022-09-28 [4] CRAN (R 4.2.1) #> parallelly 1.35.0 2023-03-23 [1] CRAN (R 4.2.3) #> parsnip * 1.1.0 2023-04-12 [1] CRAN (R 4.2.3) #> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.2.3) #> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.3) #> prodlim 2023.03.31 2023-04-02 [1] CRAN (R 4.2.3) #> purrr * 1.0.1 2023-01-10 [1] CRAN (R 4.2.3) #> R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.3) #> Rcpp 1.0.10 2023-01-22 [1] CRAN (R 4.2.3) #> readr * 2.1.4 2023-02-10 [1] CRAN (R 4.2.3) #> recipes * 1.0.5 2023-02-20 [1] CRAN (R 4.2.3) #> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.3) #> rlang 1.1.0 2023-03-14 [1] CRAN (R 4.2.3) #> rmarkdown 2.21 2023-03-26 [1] CRAN (R 4.2.3) #> rpart 4.1.19 2022-10-21 [4] CRAN (R 4.2.1) #> rsample * 1.1.1 2022-12-07 [1] CRAN (R 4.2.3) #> rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.2.3) #> scales * 1.2.1 2022-08-20 [1] CRAN (R 4.2.3) #> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.3) #> stringi 1.7.12 2023-01-11 [1] CRAN (R 4.2.3) #> stringr * 1.5.0 2022-12-02 [1] CRAN (R 4.2.3) #> survival * 3.5-3 2023-02-12 [4] CRAN (R 4.2.2) #> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.2.3) #> tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.2.3) #> tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.2.3) #> tidyselect 1.2.0 
2022-10-10 [1] CRAN (R 4.2.3) #> tidyverse * 2.0.0 2023-02-22 [1] CRAN (R 4.2.3) #> timechange 0.2.0 2023-01-11 [1] CRAN (R 4.2.3) #> timeDate 4022.108 2023-01-07 [1] CRAN (R 4.2.3) #> tune * 1.1.1 2023-04-11 [1] CRAN (R 4.2.3) #> tzdb 0.3.0 2022-03-28 [1] CRAN (R 4.2.3) #> utf8 1.2.3 2023-01-31 [1] CRAN (R 4.2.3) #> vctrs 0.6.1 2023-03-22 [1] CRAN (R 4.2.3) #> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.3) #> workflows * 1.1.3 2023-02-22 [1] CRAN (R 4.2.3) #> workflowsets * 1.0.1 2023-04-06 [1] CRAN (R 4.2.3) #> xfun 0.38 2023-03-24 [1] CRAN (R 4.2.3) #> xgboost 1.7.5.1 2023-03-30 [1] CRAN (R 4.2.3) #> yaml 2.3.7 2023-01-23 [1] CRAN (R 4.2.3) #> yardstick * 1.1.0 2022-09-07 [1] CRAN (R 4.2.3) #> #> [1] /home/bcarlin/R/x86_64-pc-linux-gnu-library/4.2 #> [2] /usr/local/lib/R/site-library #> [3] /usr/lib/R/site-library #> [4] /usr/lib/R/library #> #> ────────────────────────────────────────────────────────────────────────────── ```
github-actions[bot] commented 4 months ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.