mlr-org / mlr3extralearners

Extra learners for use in mlr3.
https://mlr3extralearners.mlr-org.com/
91 stars 50 forks source link

Time points used in survival learners predicted matrix (`distr`) #387

Closed bblodfon closed 1 month ago

bblodfon commented 1 month ago

Investigation

I performed a small benchmark related to this PR - see reprex below: I wanted to know across all survival mlr3 learners that produce a survival matrix (distr predict type in mlr3proba), which time points are used as columns.

Results

Most survival learners use all the train time points (this plays a large role when computing metrics such as the IBS, and in making comparisons fair). The learners that differ are the following (see the `which_times` output at the end of the reprex below):

library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)

# Start from every registered survival learner, then drop the ones excluded
# from this benchmark (deep-learning models, learners that take too much time
# or have issues such as mboost, etc.).
lrn_ids = mlr_learners$keys("^surv")
lrn_ids = lrn_ids[!grepl("blackboost|mboost|deep|pchazard|coxtime|priority|dnn|loghaz|gamboost", lrn_ids)]

# Keep only the learners that can predict `distr`: ids of the other learners
# evaluate to NULL, which unlist() drops.
lrn_ids = unlist(lapply(lrn_ids, function(key) {
  if ("distr" %in% lrn(key)$predict_types) key
}))

lrn_ids # ~18 survival learners
#>  [1] "surv.akritas"     "surv.aorsf"       "surv.bart"        "surv.cforest"    
#>  [5] "surv.coxboost"    "surv.coxph"       "surv.ctree"       "surv.cv_coxboost"
#>  [9] "surv.cv_glmnet"   "surv.flexible"    "surv.glmnet"      "surv.kaplan"     
#> [13] "surv.nelson"      "surv.parametric"  "surv.penalized"   "surv.ranger"     
#> [17] "surv.rfsrc"       "surv.xgboost.cox"

task = tsk("gbcs")
set.seed(42)
part = partition(task, ratio = 0.5)

# Candidate time-point sets; the columns of each predicted survival matrix
# are compared against these later.

# whole task
all_times = task$unique_times()
all_event_times = task$unique_event_times()

# train split
train_times = task$unique_times(part$train)
train_event_times = task$unique_event_times(part$train)

# test split: derive the sorted unique times by hand, once restricted to the
# rows with status 1 and once over all rows
test_times = task$times(part$test)
test_status = task$status(part$test)
test_event_times = test_times[test_status == 1] |> unique() |> sort()
test_times = test_times |> unique() |> sort()

# For every learner id: train on the train split, predict on the test split and
# return the time points (column names of the predicted `distr` matrix) as a
# numeric vector. The result is an unnamed list in the order of `lrn_ids`.
res = lapply(lrn_ids, function(id) {
  print(id)
  learner = lrn(id)

  # Learner-specific settings; switch() yields NULL for ids that are not
  # listed, so those learners keep their defaults.
  extra_values = switch(id,
    "surv.parametric" = list(discrete = TRUE),
    # low settings to make computation faster
    "surv.bart" = list(nskip = 1, ndpost = 3, keepevery = 2, mc.cores = 14),
    # id was misspelled as "surv.cforect", so `cores` was never applied
    "surv.cforest" = list(cores = 14),
    "surv.ranger" = list(num.threads = 14)
  )
  if (!is.null(extra_values)) {
    learner$param_set$set_values(.values = extra_values)
  }

  learner$train(task, part$train)
  p = learner$predict(task, part$test)

  # return discrete times for which we have the predicted S(times)
  as.numeric(colnames(p$data$distr))
})
#> [1] "surv.akritas"
#> [1] "surv.aorsf"
#> [1] "surv.bart"
#> [1] "surv.cforest"
#> [1] "surv.coxboost"
#> [1] "surv.coxph"
#> [1] "surv.ctree"
#> [1] "surv.cv_coxboost"
#> [1] "surv.cv_glmnet"
#> [1] "surv.flexible"
#> [1] "surv.glmnet"
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see
#> parameter 's').
#> [1] "surv.kaplan"
#> [1] "surv.nelson"
#> [1] "surv.parametric"
#> [1] "surv.penalized"
#> [1] "surv.ranger"
#> [1] "surv.rfsrc"
#> [1] "surv.xgboost.cox"

# label each learner's vector of time points with the learner id
names(res) = lrn_ids

# example times:
head(res[["surv.aorsf"]])
#> [1]  72 177 210 294 311 323

# For every learner, report which of the candidate time-point sets is exactly
# the set of time points used as columns of its predicted survival matrix.
# Each element is a character vector of matching labels (character(0) if the
# learner's time points match none of the candidates).
which_times = lapply(lrn_ids, function(id) {
  times = res[[id]]

  # candidate sets, keyed by the label reported in the result
  candidates = list(
    train = train_times,
    train_event = train_event_times,
    test = test_times,
    test_event = test_event_times,
    all = all_times,
    all_event = all_event_times
  )

  # Require equal lengths before the element-wise comparison: `==` recycles
  # the shorter vector, which could otherwise produce a false match (and an
  # empty `times` would match every candidate). isTRUE() maps NA to FALSE.
  is_match = vapply(candidates, function(ref) {
    isTRUE(length(times) == length(ref) && all(times == ref))
  }, logical(1))

  names(which(is_match))
})

names(which_times) = lrn_ids

# Results: which time points are used by each learner in the predicted survival matrix?
which_times
#> $surv.akritas
#> character(0)
#> 
#> $surv.aorsf
#> [1] "test_event"
#> 
#> $surv.bart
#> [1] "train"
#> 
#> $surv.cforest
#> [1] "train"
#> 
#> $surv.coxboost
#> [1] "train"
#> 
#> $surv.coxph
#> [1] "train"
#> 
#> $surv.ctree
#> [1] "train"
#> 
#> $surv.cv_coxboost
#> [1] "train"
#> 
#> $surv.cv_glmnet
#> [1] "train"
#> 
#> $surv.flexible
#> [1] "train"
#> 
#> $surv.glmnet
#> [1] "train"
#> 
#> $surv.kaplan
#> [1] "train"
#> 
#> $surv.nelson
#> [1] "train"
#> 
#> $surv.parametric
#> character(0)
#> 
#> $surv.penalized
#> character(0)
#> 
#> $surv.ranger
#> [1] "train_event"
#> 
#> $surv.rfsrc
#> [1] "train_event"
#> 
#> $surv.xgboost.cox
#> [1] "train"

Created on 2024-09-26 with reprex v2.1.1

bcjaeger commented 1 month ago

Hey John! I think harmonizing is a good idea, and it's much easier to align aorsf with the other learners than to align the other learners with aorsf. My rationale was that evaluating model predictions at the times when events occur should improve efficiency compared with evaluating the predictions at times around those points, and it avoids potentially missing event times in the testing data that occur before the first or after the last event time in the training data. But in most cases I think the event times will be very similar in training versus testing data.

bblodfon commented 1 month ago

See https://github.com/mlr-org/mlr3extralearners/pull/385 for the time point harmonization.

In the code example I now have the 3 RSFs (ranger, aorsf and rfsrc) that provide the unique train event time points, while all the rest of the learners provide the unique train time points for the survival matrix during prediction.

bblodfon commented 1 month ago

https://github.com/mlr-org/mlr3extralearners/pull/385