mlr-org / mlr3proba

Probabilistic Learning for mlr3
https://mlr3proba.mlr-org.com
GNU Lesser General Public License v3.0

Warning when using `breslow` PipeOp with `AutoTuner` #344

Closed jemus42 closed 8 months ago

jemus42 commented 8 months ago

I'm seeing a PipeOp-related issue with the new breslow estimator:

6: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.

The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.

Although this is only a warning, it is enough to trigger the fallback learner in our benchmark, which unfortunately sinks the entire learner.
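For reference, the fix the warning asks for would look roughly like the following sketch (assuming mlr3pipelines' R6 conventions; `PipeOpWithLearner` and its passthrough train/predict are hypothetical): a PipeOp with an extra construction argument overloads the private `.additional_phash_input()` so `$phash` reflects that argument.

``` r
library(R6)
library(mlr3pipelines)

# Hypothetical PipeOp taking a learner as an extra construction argument.
PipeOpWithLearner = R6Class("PipeOpWithLearner",
  inherit = PipeOp,
  public = list(
    learner = NULL,
    initialize = function(learner, id = "withlearner", param_vals = list()) {
      self$learner = learner
      super$initialize(id = id, param_vals = param_vals,
        input = data.table::data.table(name = "input", train = "*", predict = "*"),
        output = data.table::data.table(name = "output", train = "*", predict = "*")
      )
    }
  ),
  private = list(
    # feed the wrapped learner's hash into the PipeOp's phash, so two
    # PipeOps wrapping different learners no longer hash identically
    .additional_phash_input = function() self$learner$hash,
    .train = function(inputs) { self$state = list(); inputs },
    .predict = function(inputs) inputs
  )
)
```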

Reprex

library(mlr3)
library(mlr3pipelines)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3tuning)
#> Loading required package: paradox

base_lrn = lrn("surv.xgboost", tree_method = "hist", booster = "gbtree", objective = "survival:cox")

grph_lrn = po("encode", method = "treatment") %>>%
  ppl("distrcompositor", learner = base_lrn, form = "ph", estimator = "breslow") |>
  as_learner()

# Train/predict without AutoTuner seems fine
grph_lrn$train(tsk("whas"))
(preds = grph_lrn$predict(tsk("whas")))
#> <PredictionSurv> for 481 observations:
#>     row_ids time status       crank          lp     distr
#>           1    1   TRUE  1.93616793  1.93616793 <list[1]>
#>           2    1   TRUE  1.21409347  1.21409347 <list[1]>
#>           3    1   TRUE  1.93616793  1.93616793 <list[1]>
#> ---                                                      
#>         479  807  FALSE -0.54297472 -0.54297472 <list[1]>
#>         480   39   TRUE -0.50549181 -0.50549181 <list[1]>
#>         481   68   TRUE -0.02851581 -0.02851581 <list[1]>
preds$score(msr("surv.cindex"))
#> surv.cindex 
#>   0.7956677

at = AutoTuner$new(
  learner = grph_lrn,
  search_space = ps(
    surv.xgboost.nrounds = p_int(10, 20),
    surv.xgboost.eta = p_dbl(0, 1)
  ),
  resampling = rsmp("holdout"),
  measure = msr("surv.rcll"),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search"),
  store_tuning_instance = FALSE,
  store_benchmark_result = FALSE,
  store_models = FALSE
)

at$train(tsk("whas"))
#> INFO  [16:58:08.210] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=3, k=0]'
#> INFO  [16:58:08.233] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.241] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.259] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.320] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.337] [bbotk] Result of batch 1:
#> INFO  [16:58:08.338] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.338] [bbotk]                    10        0.4507898  14.05608        0      0
#> INFO  [16:58:08.338] [bbotk]  runtime_learners
#> INFO  [16:58:08.338] [bbotk]             0.053
#> INFO  [16:58:08.339] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.343] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.345] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.406] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.425] [bbotk] Result of batch 2:
#> INFO  [16:58:08.425] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.425] [bbotk]                    19        0.9423549  14.82386        0      0
#> INFO  [16:58:08.425] [bbotk]  runtime_learners
#> INFO  [16:58:08.425] [bbotk]             0.055
#> INFO  [16:58:08.427] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.430] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.433] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.491] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.506] [bbotk] Result of batch 3:
#> INFO  [16:58:08.506] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.506] [bbotk]                    13         0.860655  14.71267        0      0
#> INFO  [16:58:08.506] [bbotk]  runtime_learners
#> INFO  [16:58:08.506] [bbotk]             0.052
#> INFO  [16:58:08.510] [bbotk] Finished optimizing after 3 evaluation(s)
#> INFO  [16:58:08.510] [bbotk] Result:
#> INFO  [16:58:08.510] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta learner_param_vals  x_domain surv.rcll
#> INFO  [16:58:08.510] [bbotk]                    10        0.4507898         <list[10]> <list[2]>  14.05608
(preds_at = at$predict(tsk("whas")))
#> <PredictionSurv> for 481 observations:
#>     row_ids time status      crank         lp     distr
#>           1    1   TRUE 13.8682575 13.8682575 <list[1]>
#>           2    1   TRUE 12.2799969 12.2799969 <list[1]>
#>           3    1   TRUE 12.8481684 12.8481684 <list[1]>
#> ---                                                    
#>         479  807  FALSE -0.7482064 -0.7482064 <list[1]>
#>         480   39   TRUE  1.2740611  1.2740611 <list[1]>
#>         481   68   TRUE  1.6589222  1.6589222 <list[1]>
preds_at$score(msr("surv.cindex"))
#> surv.cindex 
#>   0.9055631

Created on 2024-01-08 with reprex v2.0.2

Session info

``` r
sessionInfo()
#> R version 4.3.2 (2023-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Sonoma 14.2.1
#>
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: Europe/Berlin
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base
#>
#> other attached packages:
#> [1] mlr3tuning_0.19.2            paradox_0.11.1
#> [3] mlr3extralearners_0.7.1-9000 mlr3proba_0.5.7
#> [5] mlr3pipelines_0.5.0-2        mlr3_0.17.1
#>
#> loaded via a namespace (and not attached):
#>  [1] utf8_1.2.4           future_1.33.1        generics_0.1.3
#>  [4] distr6_1.8.4         lattice_0.22-5       listenv_0.9.0
#>  [7] digest_0.6.33        magrittr_2.0.3       evaluate_0.23
#> [10] grid_4.3.2           ooplah_0.2.0         fastmap_1.1.1
#> [13] jsonlite_1.8.8       xgboost_1.7.6.1      Matrix_1.6-4
#> [16] backports_1.4.1      survival_3.5-7       param6_0.2.4
#> [19] fansi_1.0.6          scales_1.3.0         RhpcBLASctl_0.23-42
#> [22] codetools_0.2-19     palmerpenguins_0.1.1 cli_3.6.2
#> [25] rlang_1.1.2          crayon_1.5.2         future.apply_1.11.1
#> [28] parallelly_1.36.0    mlr3viz_0.7.0        splines_4.3.2
#> [31] munsell_0.5.0        reprex_2.0.2         withr_2.5.2
#> [34] yaml_2.3.8           tools_4.3.2          parallel_4.3.2
#> [37] uuid_1.1-1           set6_0.2.6           checkmate_2.3.1
#> [40] dplyr_1.1.4          colorspace_2.1-0     ggplot2_3.4.4
#> [43] globals_0.16.2       bbotk_0.7.3          vctrs_0.6.5
#> [46] R6_2.5.1             lifecycle_1.0.4      fs_1.6.3
#> [49] dictionar6_0.1.3     mlr3misc_0.13.0      pkgconfig_2.0.3
#> [52] pillar_1.9.0         gtable_0.3.4         data.table_1.14.10
#> [55] glue_1.6.2           Rcpp_1.0.11          lgr_0.4.4
#> [58] xfun_0.41            tibble_3.2.1         tidyselect_1.2.0
#> [61] rstudioapi_0.15.0    knitr_1.45           htmltools_0.5.7
#> [64] rmarkdown_2.25       compiler_4.3.2
```
jemus42 commented 8 months ago

#345 appears to solve the pipelines warning issue, but it turns out this only masked a more problematic one:

> learners$XGBCox$train(tasks$grace)
INFO  [18:12:21.824] [bbotk] Starting to optimize 6 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorCombo> [any=TRUE]'
INFO  [18:12:21.838] [bbotk] Evaluating 1 configuration(s)
INFO  [18:12:21.843] [mlr3] Running benchmark with 3 resampling iterations
INFO  [18:12:21.846] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 1/3)
INFO  [18:12:22.183] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 2/3)
INFO  [18:12:22.527] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 3/3)
Error in check_prediction_data.PredictionDataSurv(pdata) : 
  Assertion on 'pdata$distr' failed: Contains missing values (row 259, col 1).
This happened PipeOp surv.xgboost's $predict()

Leaving this here just as a reminder, as I have not managed to construct a reliable reprex yet (it happens with the benchmark setup though).

While attempting to debug this I found that in https://github.com/mlr-org/mlr3proba/blob/ae36e0861d286b820ef792ed208b54585dc14c87/R/breslow.R#L97-L102 Inf and NaN values are introduced, causing cumhaz to contain NA values, which in turn trigger the assertion failure above.
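A minimal illustration of the arithmetic involved (not the exact breslow.R code path): an extreme linear predictor overflows `exp()`, an `Inf/Inf` ratio yields NaN, and `cumsum()` then propagates it through every later entry of the cumulative hazard.

``` r
exp(710)                        # Inf: overflow for a large linear predictor
Inf / Inf                       # NaN
cumsum(c(0.1, Inf / Inf, 0.2))  # 0.1 NaN NaN -- NaN poisons all later entries
```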

jemus42 commented 8 months ago

Tried to whittle it down to a reprex, but I haven't managed to reproduce it outside of the AutoTuner; my mlr-foo might just not be sufficient here:

(depends on remotes::install_github("mlr-org/mlr3proba#345"))

# remotes::install_github("mlr-org/mlr3proba#345")

library("mlr3")
library("mlr3proba")
library("mlr3learners")
library("mlr3pipelines")
library("mlr3tuning")
#> Loading required package: paradox
requireNamespace("mlr3extralearners")
#> Loading required namespace: mlr3extralearners
# lgr::get_logger("mlr3")$set_threshold("debug")
# lgr::get_logger("bbotk")$set_threshold("debug")

learner = lrn(
  "surv.xgboost",
  tree_method = "hist",
  booster = "gbtree",
  objective = "survival:cox",
  nrounds = 57,
  eta = 0.9687533,
  max_depth = 2
)

graph_learner = ppl("distrcompositor", learner = learner, form = "ph", estimator = "breslow") |>
  as_learner()

set.seed(96)
resampling = rsmp("holdout")
resampling$instantiate(tsk("grace"))

graph_learner$train(tsk("grace"), row_ids = resampling$instance$train)
graph_learner$predict(tsk("grace"), row_ids = resampling$instance$test)
#> Error in check_prediction_data.PredictionDataSurv(pdata): Assertion on 'pdata$distr' failed: Contains missing values (row 65, col 1).
#> This happened PipeOp surv.xgboost's $predict()

Created on 2024-01-10 with reprex v2.0.2

jemus42 commented 8 months ago

Related to https://github.com/dmlc/xgboost/issues/9979