mlr-org / mlr3proba

Probabilistic Learning for mlr3
https://mlr3proba.mlr-org.com
GNU Lesser General Public License v3.0

Problems with Survival SVMs [discussion] #287

Closed: bblodfon closed this issue 2 years ago

bblodfon commented 2 years ago

Hi,

I made some effort to train and tune survival SVMs on a small dataset. Using a simple AutoTuner example, I found that the SVM survival learner can either fail (I think due to a fault in the optimization solvers) or get stuck (training never ends, with the CPU at 100%). I used lrn('surv.kaplan') as a fallback learner and set learner$timeout to deal with these issues, but I think this instability is a bad sign for a learner. The issues mostly relate to the choice of type: whenever it's not regression, there is a high chance you will run into them (in the example below, C-indexes are at or close to 0.5, partly because failed iterations fall back to the Kaplan-Meier estimator). I have also seen the SVM learner fail with type = regression, though less often.
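
To illustrate where the 0.5 comes from: the Kaplan-Meier fallback ignores all features and predicts the same survival curve for every patient, so every comparable pair is tied, and Harrell's C counts tied predictions as half-concordant. A minimal sketch on the same `veteran` data (not part of the original report; it assumes current mlr3proba behaviour):

```r
library(mlr3proba)
library(survival)  # provides the veteran data set

task = as_task_surv(x = veteran, time = 'time', event = 'status')

# Kaplan-Meier ignores the features, so every observation gets the same prediction
km = lrn('surv.kaplan')
km$train(task)
pred = km$predict(task)

# Identical risk ranks for all observations: every comparable pair is a tie
cindex = pred$score(msr('surv.cindex'))
print(cindex)
```

This is why configurations where all five folds errored show a C-index of exactly 0.5 in the tuning log.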

I'm posting the following tuning example here so that others can benefit from this investigation. Commenting out the learner$fallback and learner$timeout lines reproduces the issues I mentioned.

library(mlr3verse)
#> Loading required package: mlr3
library(mlr3proba)
library(survivalsvm)
#> Loading required package: survival

set.seed(42)
task = as_task_surv(x = veteran, time = 'time', event = 'status')
poe = po('encode')
task = poe$train(list(task))[[1]]

train_indxs = sample(seq_len(nrow(veteran)), 120)
test_indxs  = setdiff(seq_len(nrow(veteran)), train_indxs)

learner = lrn('surv.svm',
  type = to_tune(c('regression', 'vanbelle1', 'vanbelle2', 'hybrid')),
  diff.meth = to_tune(c('makediff1', 'makediff2', 'makediff3')),
  gamma.mu = to_tune(ps(
    gamma = p_dbl(1e-03, 10, logscale = TRUE),
    mu    = p_dbl(1e-03, 10, logscale = TRUE, depends = type == 'hybrid'),
    .extra_trafo = function(x, param_set) {
      list(gamma.mu = c(x$gamma, x$mu))
    },
    .allow_dangling_dependencies = TRUE
  )),
  kernel = to_tune(c('lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel'))
)

# fallback: used when the learner crashes (e.g. solver errors)
learner$fallback = lrn('surv.kaplan')

# timeout (in seconds): aborts training when the learner gets stuck
learner$timeout = c('train' = 1, 'predict' = Inf)

#learner$param_set$values$eig.tol  = 1e-03
#learner$param_set$values$conv.tol = 1e-03
#learner$param_set$values$posd.tol = 1e-03
#learner$param_set$values$opt.meth = 'ipop'
#learner$param_set$values$sigf = 2

#generate_design_random(learner$param_set$search_space(), 20)
generate_design_random(learner$param_set$search_space(), 3)$transpose()
#> [[1]]
#> [[1]]$type
#> [1] "hybrid"
#> 
#> [[1]]$diff.meth
#> [1] "makediff3"
#> 
#> [[1]]$kernel
#> [1] "lin_kernel"
#> 
#> [[1]]$gamma.mu
#> [1] 0.01853109 0.97598798
#> 
#> 
#> [[2]]
#> [[2]]$type
#> [1] "vanbelle2"
#> 
#> [[2]]$diff.meth
#> [1] "makediff3"
#> 
#> [[2]]$kernel
#> [1] "add_kernel"
#> 
#> [[2]]$gamma.mu
#> [1] 0.01089036
#> 
#> 
#> [[3]]
#> [[3]]$type
#> [1] "hybrid"
#> 
#> [[3]]$diff.meth
#> [1] "makediff3"
#> 
#> [[3]]$kernel
#> [1] "lin_kernel"
#> 
#> [[3]]$gamma.mu
#> [1] 0.931249 1.488555

ssvm_at = AutoTuner$new(
  learner = learner,
  resampling = rsmp('cv', folds = 5),
  measure = msr('surv.cindex'),
  terminator = trm('evals', n_evals = 10),
  tuner = tnr('random_search'))
ssvm_at$train(task)
#> INFO  [15:25:11.388] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]' 
#> INFO  [15:25:11.436] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:25:11.462] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:25:11.503] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:25:11.844] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:25:12.144] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:25:12.450] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:25:12.765] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:25:13.232] [mlr3] Finished benchmark 
#> INFO  [15:25:13.261] [bbotk] Result of batch 1: 
#> INFO  [15:25:13.263] [bbotk]        type diff.meth   gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:25:13.263] [bbotk]  regression      <NA> 1.99383 NA lin_kernel   0.6893636        0      0 
#> INFO  [15:25:13.263] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:25:13.263] [bbotk]             1.597 16669f80-5de6-4c79-a768-08928c934405 
#> INFO  [15:25:13.271] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:25:13.289] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:25:13.293] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:25:16.780] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:25:20.174] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:25:23.868] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:25:27.330] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:25:30.758] [mlr3] Finished benchmark 
#> INFO  [15:25:30.785] [bbotk] Result of batch 2: 
#> INFO  [15:25:30.787] [bbotk]       type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:25:30.787] [bbotk]  vanbelle2 makediff2 -3.545851 NA rbf_kernel         0.5        0      5 
#> INFO  [15:25:30.787] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:25:30.787] [bbotk]                NA c98a2b5b-9cb6-4b2d-8669-4aea5e396cde 
#> INFO  [15:25:30.796] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:25:30.823] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:25:30.828] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:25:31.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:25:32.367] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:25:33.067] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:25:33.807] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:25:34.653] [mlr3] Finished benchmark 
#> INFO  [15:25:34.679] [bbotk] Result of batch 3: 
#> INFO  [15:25:34.681] [bbotk]    type diff.meth     gamma       mu     kernel surv.cindex warnings errors 
#> INFO  [15:25:34.681] [bbotk]  hybrid makediff1 -6.114898 1.024288 rbf_kernel   0.5238242        0      0 
#> INFO  [15:25:34.681] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:25:34.681] [bbotk]             3.703 3edbcfe6-a153-47d5-b6f2-818d504735b4 
#> INFO  [15:25:34.692] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:25:34.714] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:25:34.719] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:25:38.152] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:25:41.555] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:25:45.237] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:25:48.701] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:25:52.225] [mlr3] Finished benchmark 
#> INFO  [15:25:52.255] [bbotk] Result of batch 4: 
#> INFO  [15:25:52.256] [bbotk]       type diff.meth    gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:25:52.256] [bbotk]  vanbelle2 makediff2 1.982577 NA rbf_kernel         0.5        0      5 
#> INFO  [15:25:52.256] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:25:52.256] [bbotk]                NA 0f54c084-42ca-4e08-bc4b-818bc7922e5f 
#> INFO  [15:25:52.265] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:25:52.286] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:25:52.292] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:25:55.867] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:25:59.550] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:03.482] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:07.046] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:10.614] [mlr3] Finished benchmark 
#> INFO  [15:26:10.642] [bbotk] Result of batch 5: 
#> INFO  [15:26:10.644] [bbotk]       type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:10.644] [bbotk]  vanbelle2 makediff2 -3.050726 NA lin_kernel         0.5        0      5 
#> INFO  [15:26:10.644] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:10.644] [bbotk]                NA 1d365cc0-c18b-42bc-92b0-fecac6aaac4d 
#> INFO  [15:26:10.653] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:26:10.670] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:26:10.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:26:10.932] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:26:11.186] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:11.435] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:11.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:11.925] [mlr3] Finished benchmark 
#> INFO  [15:26:11.954] [bbotk] Result of batch 6: 
#> INFO  [15:26:11.955] [bbotk]        type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:11.955] [bbotk]  regression      <NA> -5.757422 NA lin_kernel   0.6854107        0      0 
#> INFO  [15:26:11.955] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:11.955] [bbotk]             1.127 cde21793-0b12-4566-8e8b-2bb756563a27 
#> INFO  [15:26:11.965] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:26:11.988] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:26:11.996] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:26:12.304] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:26:12.608] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:12.900] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:13.192] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:13.475] [mlr3] Finished benchmark 
#> INFO  [15:26:13.503] [bbotk] Result of batch 7: 
#> INFO  [15:26:13.504] [bbotk]        type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:13.504] [bbotk]  regression      <NA> 0.2568419 NA lin_kernel   0.6893636        0      0 
#> INFO  [15:26:13.504] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:13.504] [bbotk]             1.352 f4e9fa94-2b73-4d39-8e67-f864d3a7c71b 
#> INFO  [15:26:13.513] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:26:13.531] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:26:13.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:26:14.563] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:26:15.517] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:16.501] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:17.924] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:19.006] [mlr3] Finished benchmark 
#> INFO  [15:26:19.041] [bbotk] Result of batch 8: 
#> INFO  [15:26:19.043] [bbotk]    type diff.meth     gamma       mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:19.043] [bbotk]  hybrid makediff3 -1.907343 -6.24123 add_kernel   0.5645394        0      1 
#> INFO  [15:26:19.043] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:19.043] [bbotk]                NA 8ae16fcb-7969-420b-af39-56a0ce68a74c 
#> INFO  [15:26:19.053] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:26:19.072] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:26:19.077] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:26:22.564] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:26:26.089] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:29.908] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:33.430] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:36.925] [mlr3] Finished benchmark 
#> INFO  [15:26:36.953] [bbotk] Result of batch 9: 
#> INFO  [15:26:36.955] [bbotk]       type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:36.955] [bbotk]  vanbelle2 makediff2 -1.883382 NA lin_kernel         0.5        0      5 
#> INFO  [15:26:36.955] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:36.955] [bbotk]                NA 5a765b5a-0741-4f75-95c3-9096c3916b65 
#> INFO  [15:26:36.965] [bbotk] Evaluating 1 configuration(s) 
#> INFO  [15:26:36.983] [mlr3] Running benchmark with 5 resampling iterations 
#> INFO  [15:26:36.988] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5) 
#> INFO  [15:26:37.058] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5) 
#> INFO  [15:26:37.137] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5) 
#> INFO  [15:26:37.210] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5) 
#> INFO  [15:26:37.284] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5) 
#> INFO  [15:26:37.357] [mlr3] Finished benchmark 
#> INFO  [15:26:37.386] [bbotk] Result of batch 10: 
#> INFO  [15:26:37.388] [bbotk]       type diff.meth     gamma mu     kernel surv.cindex warnings errors 
#> INFO  [15:26:37.388] [bbotk]  vanbelle1 makediff1 -5.990032 NA rbf_kernel   0.5337007        0      0 
#> INFO  [15:26:37.388] [bbotk]  runtime_learners                                uhash 
#> INFO  [15:26:37.388] [bbotk]             0.242 d249648b-6561-4155-9347-243b96263347 
#> INFO  [15:26:37.410] [bbotk] Finished optimizing after 10 evaluation(s) 
#> INFO  [15:26:37.410] [bbotk] Result: 
#> INFO  [15:26:37.412] [bbotk]        type diff.meth   gamma mu     kernel learner_param_vals  x_domain 
#> INFO  [15:26:37.412] [bbotk]  regression      <NA> 1.99383 NA lin_kernel          <list[3]> <list[3]> 
#> INFO  [15:26:37.412] [bbotk]  surv.cindex 
#> INFO  [15:26:37.412] [bbotk]    0.6893636

Created on 2022-08-15 by the reprex package (v2.0.1)

RaphaelS1 commented 2 years ago

Heya, thanks for raising the issue. To be honest, I've never managed to tune {survivalsvm} successfully (even outside of this package). It's been buggy for ages and I'm unconvinced by the underlying implementation.

Just looking at your code above, some quick comments:
1) I'd always recommend using hybrid and never tuning over type, as all the others are just special cases of hybrid where gamma or mu is 0.
2) I've noticed that the choice of kernel can affect whether training crashes.
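
Following the first suggestion, a minimal sketch of the original search space with `type` fixed to `'hybrid'`, so `mu` no longer depends on `type` and every configuration passes a two-element `gamma.mu`. It mirrors the 2022-era setup from the report and is untested against current releases (in later versions `surv.svm` may require mlr3extralearners); the fallback and timeout lines from the original example can be added unchanged.

```r
library(mlr3verse)
library(mlr3proba)
library(survivalsvm)

# type is fixed to 'hybrid'; only the remaining hyperparameters are tuned
learner = lrn('surv.svm',
  type = 'hybrid',
  diff.meth = to_tune(c('makediff1', 'makediff2', 'makediff3')),
  gamma.mu = to_tune(ps(
    gamma = p_dbl(1e-03, 10, logscale = TRUE),
    mu    = p_dbl(1e-03, 10, logscale = TRUE),
    # combine both tuned values into the single gamma.mu vector survivalsvm expects
    .extra_trafo = function(x, param_set) {
      list(gamma.mu = c(x$gamma, x$mu))
    }
  )),
  # kernel choice also affects crashes, so consider narrowing this set
  kernel = to_tune(c('lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel'))
)

# Every sampled configuration now carries two values in gamma.mu
design = generate_design_random(learner$param_set$search_space(), 3)$transpose()
lengths(lapply(design, function(x) x$gamma.mu))
```

The learner can then be passed to the same AutoTuner call as in the original example.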

Would you mind experimenting with {survivalsvm} directly and not via mlr3proba to see if the problem persists?

bblodfon commented 2 years ago

Hi Raphael,

Good to hear I'm not the only one who has found survival SVMs unstable. I wouldn't recommend this learner to anyone unless the hyperparameters are hand-picked rather than properly tuned (which is, well, not ideal).


I ran some tests with type = 'hybrid', tuning gamma.mu and kernel, and it seems that the polynomial kernel is the one causing the issue (though of course that may depend on the dataset or other factors; I can't tell). An example hyperparameter configuration that fails:

library(survivalsvm)
#> Loading required package: survival
fit = survivalsvm(Surv(time, status) ~ ., data = veteran, type = 'hybrid',
  gamma.mu = c(0.76, 0.09), diff.meth = 'makediff3',
  kernel  = 'poly_kernel')
#> Error in tcrossprod(K, Dc): non-conformable arguments

Created on 2022-08-21 by the reprex package (v2.0.1)
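
To check whether the kernel alone triggers the failure, a minimal sketch that keeps the failing hyperparameters fixed and varies only the kernel, calling {survivalsvm} directly. The `tryCatch()` wrapper only captures errors; it cannot stop a fit that hangs, so that case still needs an external timeout.

```r
library(survivalsvm)  # attaches survival, which provides Surv() and veteran

# Hyperparameters from the failing reprex above; only the kernel varies
kernels = c('lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel')

# For each kernel, record 'ok' or the error message raised by survivalsvm()
outcome = vapply(kernels, function(k) {
  tryCatch({
    survivalsvm(Surv(time, status) ~ ., data = veteran, type = 'hybrid',
      gamma.mu = c(0.76, 0.09), diff.meth = 'makediff3', kernel = k)
    'ok'
  }, error = function(e) conditionMessage(e))
}, character(1))

print(outcome)
```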

RaphaelS1 commented 2 years ago

Yup, buggy! I'm going to close the issue here. I don't think we should add a warning to the learner: in practice it will just perform badly and people will choose other learners. You might want to consider opening an issue in the survivalsvm repository, though?