mlr-org / mlr3proba

Probabilistic Learning for mlr3
https://mlr3proba.mlr-org.com/
GNU Lesser General Public License v3.0

Warning when using `breslow` PipeOp with `AutoTuner` #344

jemus42 closed this issue 10 months ago

jemus42 commented 10 months ago

I'm seeing a PipeOp-related issue with the new breslow estimator:

6: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.

The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.

Although this is only a warning, it is enough to trigger the fallback learner in our benchmark, which unfortunately sinks the entire learner.
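The warning asks any PipeOp whose constructor takes arguments other than `id` and `param_vals` to override the private `.additional_phash_input()` method. Below is a minimal sketch of that pattern using a hypothetical toy `PipeOpAddConst` (it is neither PipeOpBreslow nor the actual change made in mlr3proba#345). The construction argument `constant` changes what the operation computes, so it is returned from `.additional_phash_input()`. Two otherwise identical PipeOps built with different constants then get different `$phash` values. The sketch assumes a version of mlr3pipelines that emits this warning, such as the 0.5.0-2 listed in the session info.

```r
library(R6)
library(data.table)
library(mlr3pipelines)

# Hypothetical toy PipeOp: adds a constant, fixed at construction time, to its input.
# `constant` is a construction argument (not a hyperparameter), so it must be
# reflected in $phash via .additional_phash_input().
PipeOpAddConst = R6Class("PipeOpAddConst",
  inherit = PipeOp,
  public = list(
    initialize = function(id = "addconst", constant = 1, param_vals = list()) {
      private$.constant = constant
      super$initialize(id, param_vals = param_vals,
        input = data.table(name = "input", train = "*", predict = "*"),
        output = data.table(name = "output", train = "*", predict = "*")
      )
    }
  ),
  private = list(
    .constant = NULL,
    .train = function(inputs) {
      self$state = list()  # nothing to learn
      list(inputs[[1]] + private$.constant)
    },
    .predict = function(inputs) {
      list(inputs[[1]] + private$.constant)
    },
    # Return everything that distinguishes this operation beyond id/param_vals;
    # overriding this also silences the base-class warning.
    .additional_phash_input = function() private$.constant
  )
)

op1 = PipeOpAddConst$new(constant = 1)
op2 = PipeOpAddConst$new(constant = 2)
op1$phash != op2$phash  # different constants lead to different phashes
```

Returning the raw argument works here because it is a plain value. For encapsulated objects such as a wrapped learner, the warning text suggests returning their hashes instead.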

Reprex

library(mlr3)
library(mlr3pipelines)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3tuning)
#> Loading required package: paradox

base_lrn = lrn("surv.xgboost", tree_method = "hist", booster = "gbtree", objective = "survival:cox")

grph_lrn = po("encode", method = "treatment") %>>%
  ppl("distrcompositor", learner = base_lrn, form = "ph", estimator = "breslow") |>
  as_learner()

# Train/predict without AutoTuner seems fine
grph_lrn$train(tsk("whas"))
(preds = grph_lrn$predict(tsk("whas")))
#> <PredictionSurv> for 481 observations:
#>     row_ids time status       crank          lp     distr
#>           1    1   TRUE  1.93616793  1.93616793 <list[1]>
#>           2    1   TRUE  1.21409347  1.21409347 <list[1]>
#>           3    1   TRUE  1.93616793  1.93616793 <list[1]>
#> ---                                                      
#>         479  807  FALSE -0.54297472 -0.54297472 <list[1]>
#>         480   39   TRUE -0.50549181 -0.50549181 <list[1]>
#>         481   68   TRUE -0.02851581 -0.02851581 <list[1]>
preds$score(msr("surv.cindex"))
#> surv.cindex 
#>   0.7956677

at = AutoTuner$new(
  learner = grph_lrn,
  search_space = ps(
    surv.xgboost.nrounds = p_int(10, 20),
    surv.xgboost.eta = p_dbl(0, 1)
  ),
  resampling = rsmp("holdout"),
  measure = msr("surv.rcll"),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search"),
  store_tuning_instance = FALSE,
  store_benchmark_result = FALSE,
  store_models = FALSE
)

at$train(tsk("whas"))
#> INFO  [16:58:08.210] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=3, k=0]'
#> INFO  [16:58:08.233] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.241] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.259] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.320] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.337] [bbotk] Result of batch 1:
#> INFO  [16:58:08.338] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.338] [bbotk]                    10        0.4507898  14.05608        0      0
#> INFO  [16:58:08.338] [bbotk]  runtime_learners
#> INFO  [16:58:08.338] [bbotk]             0.053
#> INFO  [16:58:08.339] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.343] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.345] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.406] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.425] [bbotk] Result of batch 2:
#> INFO  [16:58:08.425] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.425] [bbotk]                    19        0.9423549  14.82386        0      0
#> INFO  [16:58:08.425] [bbotk]  runtime_learners
#> INFO  [16:58:08.425] [bbotk]             0.055
#> INFO  [16:58:08.427] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:58:08.430] [mlr3] Running benchmark with 1 resampling iterations
#> INFO  [16:58:08.433] [mlr3] Applying learner 'encode.surv.xgboost' on task 'whas' (iter 1/1)
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.491] [mlr3] Finished benchmark
#> Warning: PipeOp PipeOpBreslow has construction arguments besides 'id' and 'param_vals' but does not overload the private '.additional_phash_input()' function.
#> 
#> The hash and phash of a PipeOp must differ when it represents a different operation; since PipeOpBreslow has construction arguments that could change the operation that is performed by it, it is necessary for the $hash and $phash to reflect this. `.additional_phash_input()` should return all the information (e.g. hashes of encapsulated items) that should additionally be hashed; read the help of ?PipeOp for more information.
#> 
#> This warning will become an error in the future.
#> INFO  [16:58:08.506] [bbotk] Result of batch 3:
#> INFO  [16:58:08.506] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta surv.rcll warnings errors
#> INFO  [16:58:08.506] [bbotk]                    13         0.860655  14.71267        0      0
#> INFO  [16:58:08.506] [bbotk]  runtime_learners
#> INFO  [16:58:08.506] [bbotk]             0.052
#> INFO  [16:58:08.510] [bbotk] Finished optimizing after 3 evaluation(s)
#> INFO  [16:58:08.510] [bbotk] Result:
#> INFO  [16:58:08.510] [bbotk]  surv.xgboost.nrounds surv.xgboost.eta learner_param_vals  x_domain surv.rcll
#> INFO  [16:58:08.510] [bbotk]                    10        0.4507898         <list[10]> <list[2]>  14.05608
(preds_at = at$predict(tsk("whas")))
#> <PredictionSurv> for 481 observations:
#>     row_ids time status      crank         lp     distr
#>           1    1   TRUE 13.8682575 13.8682575 <list[1]>
#>           2    1   TRUE 12.2799969 12.2799969 <list[1]>
#>           3    1   TRUE 12.8481684 12.8481684 <list[1]>
#> ---                                                    
#>         479  807  FALSE -0.7482064 -0.7482064 <list[1]>
#>         480   39   TRUE  1.2740611  1.2740611 <list[1]>
#>         481   68   TRUE  1.6589222  1.6589222 <list[1]>
preds_at$score(msr("surv.cindex"))
#> surv.cindex 
#>   0.9055631

Created on 2024-01-08 with reprex v2.0.2

Session info

``` r
sessionInfo()
#> R version 4.3.2 (2023-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Sonoma 14.2.1
#> 
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.11.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: Europe/Berlin
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#> 
#> other attached packages:
#> [1] mlr3tuning_0.19.2 paradox_0.11.1
#> [3] mlr3extralearners_0.7.1-9000 mlr3proba_0.5.7
#> [5] mlr3pipelines_0.5.0-2 mlr3_0.17.1
#> 
#> loaded via a namespace (and not attached):
#> [1] utf8_1.2.4 future_1.33.1 generics_0.1.3
#> [4] distr6_1.8.4 lattice_0.22-5 listenv_0.9.0
#> [7] digest_0.6.33 magrittr_2.0.3 evaluate_0.23
#> [10] grid_4.3.2 ooplah_0.2.0 fastmap_1.1.1
#> [13] jsonlite_1.8.8 xgboost_1.7.6.1 Matrix_1.6-4
#> [16] backports_1.4.1 survival_3.5-7 param6_0.2.4
#> [19] fansi_1.0.6 scales_1.3.0 RhpcBLASctl_0.23-42
#> [22] codetools_0.2-19 palmerpenguins_0.1.1 cli_3.6.2
#> [25] rlang_1.1.2 crayon_1.5.2 future.apply_1.11.1
#> [28] parallelly_1.36.0 mlr3viz_0.7.0 splines_4.3.2
#> [31] munsell_0.5.0 reprex_2.0.2 withr_2.5.2
#> [34] yaml_2.3.8 tools_4.3.2 parallel_4.3.2
#> [37] uuid_1.1-1 set6_0.2.6 checkmate_2.3.1
#> [40] dplyr_1.1.4 colorspace_2.1-0 ggplot2_3.4.4
#> [43] globals_0.16.2 bbotk_0.7.3 vctrs_0.6.5
#> [46] R6_2.5.1 lifecycle_1.0.4 fs_1.6.3
#> [49] dictionar6_0.1.3 mlr3misc_0.13.0 pkgconfig_2.0.3
#> [52] pillar_1.9.0 gtable_0.3.4 data.table_1.14.10
#> [55] glue_1.6.2 Rcpp_1.0.11 lgr_0.4.4
#> [58] xfun_0.41 tibble_3.2.1 tidyselect_1.2.0
#> [61] rstudioapi_0.15.0 knitr_1.45 htmltools_0.5.7
#> [64] rmarkdown_2.25 compiler_4.3.2
```
jemus42 commented 10 months ago

#345 appears to solve the pipelines warning issue, but it turns out this only masked a more serious problem:

> learners$XGBCox$train(tasks$grace)
INFO  [18:12:21.824] [bbotk] Starting to optimize 6 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorCombo> [any=TRUE]'
INFO  [18:12:21.838] [bbotk] Evaluating 1 configuration(s)
INFO  [18:12:21.843] [mlr3] Running benchmark with 3 resampling iterations
INFO  [18:12:21.846] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 1/3)
INFO  [18:12:22.183] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 2/3)
INFO  [18:12:22.527] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.encode.removeconstants.surv.xgboost' on task 'grace' (iter 3/3)
Error in check_prediction_data.PredictionDataSurv(pdata) : 
  Assertion on 'pdata$distr' failed: Contains missing values (row 259, col 1).
This happened PipeOp surv.xgboost's $predict()

I'm leaving this here as a reminder, since I have not yet managed to construct a reliable reprex (it does happen with the benchmark setup, though).

While debugging this, I found that the code at https://github.com/mlr-org/mlr3proba/blob/ae36e0861d286b820ef792ed208b54585dc14c87/R/breslow.R#L97-L102 introduces Inf and NaN values. These leave cumhaz with NA values, which in turn cause the error above.
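One way such values can arise is sketched below in base R. This is a minimal re-implementation of the Breslow estimator, not the code in R/breslow.R. It computes the baseline cumulative hazard as H0(t) = sum over event times t_i <= t of d_i / sum_{j: t_j >= t_i} exp(lp_j), then predicts S(t | x) = exp(-H0(t) * exp(lp_x)). The toy linear predictors (+/-800) are hypothetical. They are chosen so that `exp()` overflows to Inf (anything above log(.Machine$double.xmax), about 709.78) or underflows to 0. The results are a risk-set sum of Inf (so d / Inf = 0), an empty effective risk set (so d / 0 = Inf), and the products 0 * Inf and Inf * 0, which are NaN. Because `is.na(NaN)` is TRUE, checkmate reports these entries as missing values.

```r
# Hypothetical helper: Breslow baseline cumulative hazard (not mlr3proba's code).
breslow_cumhaz = function(times, status, lp) {
  risk = exp(lp)  # Inf for lp > ~709.78, 0 for very negative lp
  event_times = sort(unique(times[status == 1]))
  increments = vapply(event_times, function(t) {
    n_events = sum(times == t & status == 1)  # d_i: events at time t
    risk_set = sum(risk[times >= t])          # sum of exp(lp) over the risk set at t
    n_events / risk_set                       # d / Inf = 0, d / 0 = Inf
  }, numeric(1))
  stats::setNames(cumsum(increments), event_times)
}

# Hypothetical helper: survival prediction S(t | x) = exp(-H0(t) * exp(lp_new)).
breslow_surv = function(cumhaz, lp_new) exp(-cumhaz * exp(lp_new))

times  = c(1, 2, 3, 4)
status = c(1, 1, 1, 1)
lp     = c(800, 0, 0, -800)  # hypothetical extreme linear predictors

H0 = breslow_cumhaz(times, status, lp)  # values: 0, 0.5, 1.5, Inf
s_hi = breslow_surv(H0, 800)            # 0 * Inf at t = 1 gives NaN, then 0, 0, 0
s_lo = breslow_surv(H0, -800)           # 1, 1, 1, then Inf * 0 at t = 4 gives NaN
anyNA(s_hi)                             # TRUE: NaN counts as missing
```

Whether this matches the exact failure in the benchmark depends on the linear predictors that xgboost actually returns there, which this sketch does not reproduce.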

jemus42 commented 10 months ago

I tried to whittle it down to a reprex but haven't managed to reproduce it outside of the AutoTuner, and my mlr-foo might just not be sufficient here:

(This depends on the PR branch, installed via remotes::install_github("mlr-org/mlr3proba#345").)

# remotes::install_github("mlr-org/mlr3proba#345")

library("mlr3")
library("mlr3proba")
library("mlr3learners")
library("mlr3pipelines")
library("mlr3tuning")
#> Loading required package: paradox
requireNamespace("mlr3extralearners")
#> Loading required namespace: mlr3extralearners
# lgr::get_logger("mlr3")$set_threshold("debug")
# lgr::get_logger("bbotk")$set_threshold("debug")

learner = lrn(
  "surv.xgboost",
  tree_method = "hist",
  booster = "gbtree",
  objective = "survival:cox",
  nrounds = 57,
  eta = 0.9687533,
  max_depth = 2
)

graph_learner = ppl("distrcompositor", learner = learner, form = "ph", estimator = "breslow") |>
  as_learner()

set.seed(96)
resampling = rsmp("holdout")
resampling$instantiate(tsk("grace"))

graph_learner$train(tsk("grace"), row_ids = resampling$instance$train)
graph_learner$predict(tsk("grace"), row_ids = resampling$instance$test)
#> Error in check_prediction_data.PredictionDataSurv(pdata): Assertion on 'pdata$distr' failed: Contains missing values (row 65, col 1).
#> This happened PipeOp surv.xgboost's $predict()

Created on 2024-01-10 with reprex v2.0.2

jemus42 commented 10 months ago

Related to https://github.com/dmlc/xgboost/issues/9979