mlr-org / mlr3extralearners

Extra learners for use in mlr3.
https://mlr3extralearners.mlr-org.com/

[LRNRQ] Add <xgb.train.surv> from package <survXgboost> #224

Closed: MosquitoFan closed this issue 2 years ago.

MosquitoFan commented 2 years ago

### Algorithm

xgb.train.surv

### Package

survXgboost

### Supported types

* [ ] classif
* [ ] clust
* [ ] dens
* [ ] regr
* [x] surv

### I have checked that this is not already implemented in

* [x] mlr3
* [x] mlr3learners
* [x] mlr3extralearners
* [x] Other core packages (e.g. mlr3proba, mlr3keras)

### Why do I think this is a useful learner?

`xgb.train.surv` supports the `"distr"` prediction type through the `predict` function of the survXgboost package. A `"distr"` prediction is required to evaluate the model with `"surv.graf"` (mlr3proba) and to assess model performance further with other commonly used packages such as pec and riskRegression.

### Further Optional Comments

The current `surv.xgboost` learner in mlr3extralearners only supports the `"crank"` and `"lp"` predict types, so I think we could add a new learner that predicts `"distr"`. Apart from the input data and label, the parameters of `xgb.train.surv` are the same as those of `xgb.train` (package xgboost, learner `surv.xgboost`). The input data and label are the raw data before conversion with `xgb.DMatrix`, which prepares data for `xgb.train` in package xgboost. The predict function needs additional arguments such as `type = "surv"` and `times = 10` (e.g. 10 years on the time scale of the label; this can also be replaced by the longest time in the test labels).

I have tried to build the new learner, but I have no experience in constructing R packages, and I keep getting an error that the new learner cannot be added to the dictionary. Please find attached the R code for the new learner; maybe it helps a little.

Best

[SXGboost.zip](https://github.com/mlr-org/mlr3extralearners/files/9181542/SXGboost.zip)
sebffischer commented 2 years ago

Hey @MosquitoFan, thank you for your interest in mlr3 and sorry for the late response. I don't think it makes sense to add a new learner just for one predict type. Maybe we could instead extend the current surv.xgboost learner to allow for this predict type. Unfortunately, I am not too familiar with the survival learners myself, so I would be hesitant to review or implement this.

Fortunately, we have a survival specialist on the mlr3 team!

@RaphaelS1 What do you think here? :)

MosquitoFan commented 2 years ago

Thank you for your response. I hope @RaphaelS1 will be interested in this issue.

May I ask whether there is a video course on how to add a new learner? I am really interested in learning that!

Best!

RaphaelS1 commented 2 years ago

Hi @MosquitoFan, in general, boosting algorithms for survival analysis predict a linear predictor, which can then be composed into a full survival distribution by making some assumptions about the form of the probabilistic model (e.g. proportional hazards or accelerated failure time). In mlr3 this looks like the code below; as you can see, we now have a `distr` output.
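A sketch of what this composition means, assuming the textbook forms of the two models selected with `form = "ph"` and `form = "aft"`: given a baseline survival function $S_0(t)$ (for example a Kaplan-Meier estimate on the training data) and a predicted linear predictor $\eta$ for an observation $x$, each observation gets its own survival curve. Sign conventions for the linear predictor differ between packages, so treat this as illustrative rather than as mlr3proba's exact implementation.

```latex
\begin{aligned}
\text{PH:}\quad  & h(t \mid x) = h_0(t)\, e^{\eta}
  && \Longrightarrow \quad S(t \mid x) = S_0(t)^{\exp(\eta)} \\
\text{AFT:}\quad & \log T = \eta + \varepsilon
  && \Longrightarrow \quad S(t \mid x) = S_0\!\left(t\, e^{-\eta}\right)
\end{aligned}
```

Here $h_0$ is the baseline hazard and, for AFT, $S_0$ is the survival function of $e^{\varepsilon}$. Because every curve is derived from one shared baseline, observations with equal `lp` receive identical `distr` predictions.
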

As survXgboost provides the same functionality (with less control), it does not make sense to me to add it as a new learner or to integrate it into the existing learner. I'd therefore suggest using the code below instead.

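```r
library(mlr3)
library(mlr3proba)          # survival tasks, tgen(), surv.graf measure
library(mlr3pipelines)      # ppl("distrcompositor")
library(mlr3extralearners)  # provides the surv.xgboost learner

# Compose a distr prediction from xgboost's lp, assuming proportional hazards
lrn_ph = as_learner(ppl("distrcompositor",
    lrn("surv.xgboost", objective = "survival:cox"), form = "ph"))
# ... or assuming an accelerated failure time model
lrn_aft = as_learner(ppl("distrcompositor",
    lrn("surv.xgboost", objective = "survival:aft"), form = "aft"))

# Simulate a small survival task with 10 observations
task = tgen("simsurv")$generate(10)
p_ph = lrn_ph$train(task)$predict(task)
p_aft = lrn_aft$train(task)$predict(task)
p_ph
# Graf score (survival Brier score), which requires a distr prediction
p_ph$score(msr("surv.graf"))
p_aft$score(msr("surv.graf"))
```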

```
<PredictionSurv> for 10 observations:
    row_ids     time status      crank
          1 2.219370   TRUE -0.3834336
          2 0.933966   TRUE -0.3834336
          3 2.455688   TRUE -0.3834336
---
          8 4.170562   TRUE -0.9124171
          9 1.060284   TRUE -0.3834336
         10 5.000000  FALSE -0.9124171
            lp     distr
    -0.3834336 <list[1]>
    -0.3834336 <list[1]>
    -0.3834336 <list[1]>
---
    -0.9124171 <list[1]>
    -0.3834336 <list[1]>
    -0.9124171 <list[1]>
surv.graf 
0.2140287 
surv.graf 
0.2605627 
```

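For reference, `surv.graf` is the Graf score, i.e. the Brier score adapted to censored data, which is why it needs a full `distr` prediction: it evaluates the predicted survival probability $\hat S(t \mid x_i)$ across time points. A sketch of the standard inverse-probability-of-censoring-weighted form, assuming $\hat G$ is the Kaplan-Meier estimate of the censoring distribution, $T_i$ are the observed times, $\delta_i$ the event indicators, and $\tau$ the upper integration bound. mlr3proba's implementation details (time grid, handling of $\hat G$, how the integral is approximated) may differ.

```latex
\begin{aligned}
\mathrm{BS}(t) &= \frac{1}{n}\sum_{i=1}^{n}\left[
  \frac{\hat S(t \mid x_i)^2\,\mathbb{1}(T_i \le t,\ \delta_i = 1)}{\hat G(T_i)}
  + \frac{\bigl(1 - \hat S(t \mid x_i)\bigr)^2\,\mathbb{1}(T_i > t)}{\hat G(t)}
\right] \\
\mathrm{IGS} &= \frac{1}{\tau}\int_0^{\tau} \mathrm{BS}(t)\,dt
\end{aligned}
```

Lower is better, so on this tiny simulated task (scored on its own training data) the PH composition (0.214) did better than the AFT one (0.261).
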
MosquitoFan commented 2 years ago

Hi @RaphaelS1, thank you very much for the response, it works for me!

sebffischer commented 2 years ago

For your information: https://mlr3book.mlr-org.com/07-extending-learners.html shows how to add a learner to mlr3.

MosquitoFan commented 2 years ago

Thank you, but this page seems to return a 404 (not found).

sebffischer commented 2 years ago

We changed the URL structure of the book: https://mlr3book.mlr-org.com/extending.html