slds-lmu / paper_2023_survival_benchmark

Benchmark for Burk et al. (2024)
https://projects.lukasburk.de/survival_benchmark/
GNU General Public License v3.0

Memory issues due to tasks with a large number of unique time points #1

Closed: jemus42 closed this issue 8 months ago

jemus42 commented 9 months ago

Originally opened in proba_benchmark by jemus42: RaphaelS1/proba_benchmark#61

This can be an issue with learners such as SSVM and ranger, but not with RFSRC, which coarsens the time points to a grid internally, as Marvin mentioned.

This affects the first couple of tasks here, primarily hdfail and child from what I've seen:

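| task    | n     | p | unique time points |
|---------|-------|---|--------------------|
| hdfail  | 52422 | 7 | 10185              |
| child   | 26574 | 6 | 2838               |
| nwtco   | 4028  | 5 | 2767               |
| nafld1  | 4000  | 7 | 2738               |
| flchain | 4000  | 9 | 2064               |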

(The remaining tasks have fewer than 2000 unique time points, so this is unlikely to be an issue for them.)
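To put these counts in perspective, here is a back-of-the-envelope sketch. It assumes that a distribution prediction is stored as one dense matrix of doubles (8 bytes each) with one row per predicted observation and one column per unique time point; that storage layout is an assumption, not something measured in this thread. The n and unique-time counts are copied from the table above.

```r
# Task sizes copied from the table above
tasks <- data.frame(
  task   = c("hdfail", "child", "nwtco", "nafld1", "flchain"),
  n      = c(52422, 26574, 4028, 4000, 4000),
  n_time = c(10185, 2838, 2767, 2738, 2064)
)

# Assumption: one dense double matrix, n rows x n_time columns, 8 bytes per cell
tasks$gib <- tasks$n * tasks$n_time * 8 / 1024^3

# Largest tasks first
print(tasks[order(-tasks$gib), ], digits = 3)
```

Under that assumption, predicting on all of hdfail needs roughly 4 GiB for a single copy of such a matrix and child about 0.56 GiB, while the other three tasks stay below 0.1 GiB each.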

For hdfail, we might get away with a simple ceiling() call, which avoids time == 0; given the huge range of times there, this might be fine?

# Loading dataset from URL for simplicity as benchmark is in private github repo
hdfail <- readRDS(url("https://dump.jemu.name/survdat/hdfail.rds"))

dim(hdfail)
#> [1] 52422     7

head(hdfail$time)
#> [1] 172.5000 174.2917 174.2917 194.4167 174.2917 172.5000
range(hdfail$time)
#> [1] 1.666667e-01 1.129512e+04
summary(hdfail$time)
#>      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
#>     0.167   265.427   530.125   685.777   923.073 11295.125

length(unique(hdfail$time))
#> [1] 10185

hdfail |>
  dplyr::mutate(
    time_r2 = round(time, 2),
    time_r1 = round(time, 1),
    time_ceil = ceiling(time)
  ) |>
  dplyr::summarize(dplyr::across(dplyr::starts_with("time"), dplyr::n_distinct))
#>    time time_r2 time_r1 time_ceil
#> 1 10185   10185    7354      2030

Are there any smarter solutions? I could apply them to, e.g., hdfail and child in our data preprocessing before the benchmarking code runs, so that we don't have to hack edge cases into the preprocessing pipeline.
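One option in the spirit of the internal grid that RFSRC uses (mentioned above) is to snap all times to a fixed grid before benchmarking. The sketch below uses a hypothetical helper `coarsen_to_grid()`; it is not randomForestSRC's actual implementation, and the quantile-based grid, `n_grid = 100`, and the simulated data are assumptions chosen for illustration.

```r
# Hypothetical helper: snap each time up to the smallest point of a grid of
# n_grid quantiles of the observed times (illustration only, not rfsrc's code).
coarsen_to_grid <- function(time, n_grid = 100) {
  grid <- unique(quantile(time, probs = seq(0, 1, length.out = n_grid), names = FALSE))
  # With left.open = TRUE, findInterval() returns i with grid[i] < t <= grid[i + 1],
  # and 0 when t equals the minimum, so grid[i + 1] is the smallest grid point >= t.
  grid[findInterval(time, grid, left.open = TRUE) + 1]
}

# Simulated, strictly positive times
set.seed(1)
time <- rexp(5000, rate = 1 / 500)
coarse <- coarsen_to_grid(time, n_grid = 100)

length(unique(time))
length(unique(coarse))
```

Snapping upwards means no time becomes 0 and both the minimum and maximum are preserved, at the cost of shifting each time up by at most one grid cell.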

jemus42 commented 9 months ago

Noted: we should use floor() instead, and then bump the resulting time == 0 cases back up.
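A minimal sketch of the floor-then-bump idea, using a hypothetical helper `floor_time()`. The replacement value 0.5 for times that floor to 0 is an assumption, not a decision from this thread; the example times are taken from the hdfail output above.

```r
# Hypothetical helper: floor times to whole units, then bump cases that
# became 0 back up to zero_value (0.5 is an assumed choice).
floor_time <- function(time, zero_value = 0.5) {
  out <- floor(time)
  out[out == 0] <- zero_value
  out
}

x <- c(0.1666667, 0.375, 172.5, 174.2917, 11295.125)
floor_time(x)  # 0.5, 0.5, 172, 174, 11295
```

Any replacement value strictly between 0 and 1 keeps the coarsened times in the same (weak) order as the originals, because floor() maps every time in [1, 2) to 1.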

Also, here's an additional overview for hdfail and child:

library(ggplot2)
child <- readRDS(url("https://dump.jemu.name/survdat/child.rds"))

dim(child)
#> [1] 26574     6
summary(child$time)
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   0.003  15.000  15.000  12.231  15.000  15.000
head(sort(table(child$time)))
#> 
#>  0.35 0.441 0.482 0.567 0.611 0.613 
#>     1     1     1     1     1     1
length(unique(child$time))
#> [1] 2838

ggplot(child, aes(x = time)) +
  geom_histogram(bins = 25) +
  theme_minimal()


hdfail <- readRDS(url("https://dump.jemu.name/survdat/hdfail.rds"))
dim(hdfail)
#> [1] 52422     7
summary(hdfail$time)
#>      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
#>     0.167   265.427   530.125   685.777   923.073 11295.125
head(sort(table(hdfail$time)))
#> 
#> 0.166666666666667 0.208333333333333 0.333333333333333             0.375 
#>                 1                 1                 1                 1 
#> 0.416666666666667 0.541666666666667 
#>                 1                 1
length(unique(hdfail$time))
#> [1] 10185

ggplot(hdfail, aes(x = time)) +
  geom_histogram(bins = 25) +
  theme_minimal()

Created on 2023-10-13 with reprex v2.0.2

jemus42 commented 9 months ago

This is another likely candidate, but I'm not sure whether it's a RAM issue or something else: on our workstation with 250 GB of RAM it still errors, and locally it crashes my R session:

# CIF + hdfail
library(mlr3verse)
library(mlr3proba)
library(mlr3extralearners) # provides the surv.cforest learner

# Loading dataset from URL for simplicity as benchmark is in private github repo
test_datasets <- readRDS(url("https://dump.jemu.name/survdat/hdfail.rds"))

# Quick check:
dim(test_datasets)

# Converting to mlr3 task, stratify by status is done for all tasks
task = as_task_surv(test_datasets, target = "time", event = "status")
task$set_col_roles("status", add_to = "stratum")

# assemble CIF learner (surv.cforest) without tuning
learner = lrn("surv.cforest", ntree = 5000)
learner$predict_type = "crank"

lrn_ppl = po("fixfactors") %>>%
  # po("imputesample", affect_columns = selector_type("factor")) %>>%
  po("collapsefactors",
     no_collapse_above_prevalence = 0.05,
     target_level_count = 5) %>>%
  po("encode", method = "treatment") %>>%
  po("removeconstants") %>>%
  ppl("distrcompositor", learner = learner, form = "ph") |>
  as_learner()

lrn_ppl$graph$plot()

lrn_ppl$train(task)
lrn_ppl$predict(task)

mlr3verse::mlr3verse_info()

The error in the workstation log is:

INFO  [17:43:29.899] [mlr3] Applying learner 'CIF' on task 'hdfail' (iter 1/1)
INFO  [17:43:30.373] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
INFO  [17:43:30.427] [bbotk] Evaluating 1 configuration(s)
INFO  [17:43:30.479] [mlr3] Running benchmark with 1 resampling iterations
INFO  [17:43:30.509] [mlr3] Applying learner 'fixfactors.imputesample.collapsefactors.removeconstants.distrcompositor.kaplan.surv.cforest.compose_distr' on task 'hdfail' (iter 1/1)
INFO  [18:21:34.849] [mlr3] Finished benchmark
Warning in grepl(paste0("^", id, "__", collapse = "|"), names(pars)) :
  TRE pattern compilation error 'Out of memory'
Error in grepl(paste0("^", id, "__", collapse = "|"), names(pars)) :
  invalid regular expression '^WeightDisc1__|^WeightDisc2__|^WeightDisc3__|^WeightDisc4__|^WeightDisc5__|^WeightDisc6__|^WeightDisc7__|^WeightDisc8__|^WeightDisc9__|^WeightDisc10__|^WeightDisc11__|^WeightDisc12__|^WeightDisc13__|^WeightDisc14__|^WeightDisc15__|^WeightDisc16__|^WeightDisc17__|^WeightDisc18__|^WeightDisc19__|^WeightDisc20__|^WeightDisc21__|^WeightDisc22__|^WeightDisc23__|^WeightDisc24__|^WeightDisc25__|^WeightDisc26__|^WeightDisc27__|^WeightDisc28__|^WeightDisc29__|^WeightDisc30__|^WeightDisc31__|^WeightDisc32__|^WeightDisc33__|^WeightDisc34__|^WeightDisc35__|^WeightDisc36__|^WeightDisc37__|^WeightDisc38__|^WeightDisc39__|^WeightDisc40__|^WeightDisc41__|^WeightDisc42__|^WeightDisc43__|^WeightDisc44__|^WeightDisc45__|^WeightDisc46__|^WeightDisc47__|^WeightDisc48__|^WeightDisc49__|^WeightDisc50__|^WeightDisc51__|^WeightDisc52__|^WeightDisc53__|^WeightDisc54__|^WeightDisc55__|^WeightDisc56__|^WeightDisc57__|^WeightDisc58__|^WeightDisc59__|^WeightDisc60__|^WeightDisc61__|^Wei

### [bt]: Job terminated with an exception [batchtools job.id=1193]
### [bt]: Calculation finished!

jemus42 commented 9 months ago

Akritas and ranger have received improvements that reduce their memory usage, but the surv.cforest reprex above still fails, on the cluster as well. I'm not sure there's anything left we can or should do, but the error in question comes from here: https://github.com/xoopR/distr6/blob/1854b22b53da6ff09939f99f4f7f4cc0d54b3660/R/Wrapper_VectorDistribution.R#L1157
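The failing grepl() call in the log pastes one `^<id>__` branch per component id into a single regular expression, which TRE then fails to compile ("Out of memory"). A hypothetical way to do the same prefix matching without compiling any regex is to test each prefix literally with `startsWith()`. This is only an illustration with made-up parameter names, not a patch for distr6.

```r
# Hypothetical alternative to one huge regex alternation: check each
# "<id>__" prefix literally, accumulating matches in a single logical vector.
match_prefixes <- function(ids, par_names) {
  hits <- logical(length(par_names))
  for (prefix in paste0(ids, "__")) {
    hits <- hits | startsWith(par_names, prefix)
  }
  hits
}

ids <- paste0("WeightDisc", 1:100)
par_names <- c("WeightDisc1__mean", "WeightDisc5__sd", "WeightDisc50__rate",
               "WeightDisc150__sd", "Other__rate", "WeightDisc1_x")

# For a small number of ids the regex still compiles, so both can be compared
regex_hits <- grepl(paste0("^", ids, "__", collapse = "|"), par_names)
prefix_hits <- match_prefixes(ids, par_names)
identical(regex_hits, prefix_hits)
```

This trades regex compilation for a loop over ids, so run time grows with ids × parameter names, while the extra memory stays proportional to the number of names. Unlike the regex, it treats ids literally, which only makes a difference if an id contains regex metacharacters.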