Closed ixxmu closed 9 months ago
💡专注R语言在🩺生物医学中的使用
设为“星标”,精彩不错过
2024年,如果你要做高阶生存分析,也就是各种机器学习算法相关的生存分析,我强烈推荐你学习mlr3
。
如果是1年前我是强烈不推荐的,因为那时的mlr3
发展的还不够成熟,很多语法一直在修改,可能你的代码今天能用,明天就不能用了,而且小bug不断。
但是它的发展非常快速,大家可以去看看我在2022年发的十几篇mlr3
的古早教程,都是看着mlr3book写的,当时连各种语法糖都没有,还是完全的R6语法,每个函数都巨长而且不好理解。再看看现在的mlr3
,真的是天差地别。它们的官方教程mlr3book
我每隔几个月都会去看一看,每次都会发现有更新。从2020年到现在一直在快速更新中!
mlr3
发展到现在(现在是20240102)我个人认为基本很成熟了,应该不会再有大的改动了,各种方法的语法基本都非常完善了,不会再出现“今天的代码明天不能用”的情况。而且对各种生存分析的支持目前在R语言中没有对手!遥遥领先!
本来我是很看好tidymodels
的,毕竟珠玉在前(经典R包caret
),而且还有Rstudio这颗“大树”,但是目前看来它的发展已经远远落后于mlr3
,落后太远了!
所以,如果你要学习各种机器学习相关的生存分析,别再费劲去学各种单独的R包了,直接学习mlr3
吧。(但是一些基础的知识还是要了解一下的,有些东西无法替代,最好是先学习单独的R包,再学习mlr3,这样有些细节你可以心中有数,出了问题方便溯源)
下面我们详细介绍下如何在mlr3
中实现生存分析。
在mlr3
中实现生存分析需要额外安装mlr3proba
和mlr3extralearners
。由于一些原因这两个R包永远不会出现在CRAN中,所以安装方法略有不同,官方提供了多种安装方法,如下:
# 安装mlr3extralearners
remotes::install_github("mlr-org/mlr3extralearners")
# 安装mlr3proba
options(repos=c(
mlrorg = 'https://mlr-org.r-universe.dev',
raphaels1 = 'https://raphaels1.r-universe.dev',
CRAN = 'https://cloud.r-project.org'
))
install.packages("mlr3proba")
# or
install.packages("mlr3proba", repos = "https://mlr-org.r-universe.dev")
然后是加载R包:
library(mlr3verse)
library(mlr3proba)
library(survival)
library(mlr3extralearners)
library(dplyr)
在mlr3
中实现生存分析和做回归分类任务一样,首先是建立任务(其实你也可以先进行一些预处理再建立任务),方法是一模一样的,只不过这里是as_task_surv
而已:
tsk_rats <- as_task_surv(survival::rats, time = "time", event = "status",
type = "right", id = "rats")
tsk_rats
## <TaskSurv:rats> (300 x 5)
## * Target: time, status
## * Properties: -
## * Features (3):
## - int (1): litter
## - dbl (1): rx
## - chr (1): sex
tsk_rats$head()
## time status litter rx sex
## 1: 101 0 1 1 f
## 2: 49 1 1 0 f
## 3: 104 0 1 0 f
## 4: 91 0 2 1 m
## 5: 104 0 2 0 m
## 6: 102 0 2 0 m
注意这个task的Features中有个chr
,这里是不支持的字符型的,你可以先变成因子型再建立任务。
对于这个任务我们可以进行一些简单的可视化,比如生存曲线:
autoplot(tsk_rats, type="target")
分组也是可以的:
autoplot(tsk_rats, rhs = "sex")
还可以更换画图类型等,更加详细的内容这里就不多说了,和分类/回归都是一样的用法,感兴趣的可参考之前的推文。
autoplot(tsk_rats, type="duo")
建立好任务之后就是选择模型,目前mlr3
支持的机器学习生存模型非常多、非常全,遥遥领先!
先给大家展示下目前mlr3
支持的生存分析模型,注意一定要先加载mlr3proba
和mlr3extralearners
:
# 目前(2024.1.2)已经支持33种生存分析模型
# as.data.table(mlr_learners)[grepl("surv",mlr_learners$keys())][,c("key","label","predict_types")]
as.data.table(mlr_learners) %>%
filter(grepl("surv",mlr_learners$keys())) %>%
select(key,label,predict_types)
## key label predict_types
## 1: surv.akritas Akritas Estimator crank,distr
## 2: surv.aorsf Oblique Random Forest crank,distr
## 3: surv.bart Bayesian Additive Regression Trees crank,distr
## 4: surv.blackboost Gradient Boosting crank,distr,lp
## 5: surv.cforest Conditional Random Forest crank,distr
## 6: surv.coxboost Likelihood-based Boosting crank,distr,lp
## 7: surv.coxph Cox Proportional Hazards crank,distr,lp
## 8: surv.coxtime Cox-Time Estimator crank,distr
## 9: surv.ctree Conditional Inference Tree crank,distr
## 10: surv.cv_coxboost Likelihood-based Boosting crank,distr,lp
## 11: surv.cv_glmnet Regularized Generalized Linear Model crank,distr,lp
## 12: surv.deephit Neural Network crank,distr
## 13: surv.deepsurv Neural Network crank,distr
## 14: surv.dnnsurv Neural Network crank,distr
## 15: surv.flexible Flexible Parametric Splines crank,distr,lp
## 16: surv.gamboost Boosted Generalized Additive Model crank,distr,lp
## 17: surv.gbm Gradient Boosting crank,lp
## 18: surv.glmboost Boosted Generalized Linear Model crank,distr,lp
## 19: surv.glmnet Regularized Generalized Linear Model crank,distr,lp
## 20: surv.kaplan Kaplan-Meier Estimator crank,distr
## 21: surv.loghaz Logistic-Hazard Learner crank,distr
## 22: surv.mboost Boosted Generalized Additive Model crank,distr,lp
## 23: surv.nelson Nelson-Aalen Estimator crank,distr
## 24: surv.obliqueRSF Oblique Random Forest crank,distr
## 25: surv.parametric Fully Parametric Learner crank,distr,lp
## 26: surv.pchazard PC-Hazard Learner crank,distr
## 27: surv.penalized Penalized Regression crank,distr
## 28: surv.priority_lasso Priority Lasso lp,response
## 29: surv.ranger Random Forest crank,distr
## 30: surv.rfsrc Random Forest crank,distr
## 31: surv.rpart Survival Tree crank
## 32: surv.svm Support Vector Machine crank,response
## 33: surv.xgboost Gradient Boosting crank,lp
## key label predict_types
这个数量肯定会越来越多,但是目前来说肯定是够用了,常见的它都有,你没听说过的它也有!比如xgboost/随机生存森林/生存支持向量机(这个包本身有些小问题)/coxboost等。
比如选择一个xgboost
模型并查看它的帮助文档:
lrn_surv <- lrn("surv.xgboost")
lrn_surv$help()
这种查看帮助文档的方式是R6语法,大家一定要记住!
这部分内容是mlr3
中做生存分析的核心内容(官网也是近期才更新这么详细的解释…),非常重要。
mlr3
中的生存分析的预测结果可以有4种类型,分别是:
response
: 预测的生存时间distr
: 预测的生存概率分布,连续型或者离散型lp
: 线性预测值crank
: 连续型的风险分数mlr3
中的生存分析模型会直接返回所有模型支持的类型,比如,如果一个模型支持以上4种类型,那么结果会直接返回4种类型的预测结果,如果一个模型支持lp、crank、distr这3种类型,mlr3
也会直接返回这3种类型,不会给你藏着掖着!
比如cox模型,它支持3种类型(不支持生存时间),那么mlr3
会直接返回3种:
# 这里用了自带的数据,和上面的任务有点不一样哦
tsk_rats <- tsk("rats")
#tsk_rats
split <- partition(tsk_rats)
prediction_cph <- lrn("surv.coxph")$train(tsk_rats, split$train)$
predict(tsk_rats, split$test)
head(as.data.table(prediction_cph))
## row_ids time status crank lp distr
## 1: 4 91 FALSE -2.5812496 -2.5812496 <list[1]>
## 2: 7 104 FALSE 0.6233550 0.6233550 <list[1]>
## 3: 13 104 FALSE 0.6336594 0.6336594 <list[1]>
## 4: 16 98 FALSE -2.5606407 -2.5606407 <list[1]>
## 5: 18 77 FALSE -3.4256114 -3.4256114 <list[1]>
## 6: 22 91 FALSE -2.5503363 -2.5503363 <list[1]>
可以看到结果中直接给出了crank
、lp
、distr
这3种类型。
下面详细介绍下这4种类型的预测结果。
response
生存时间(response)在生存分析的预测结果中是最少见的,可能是因为删失的存在,导致很多观测的生存时间不可用。
生存支持向量机是可以预测生存时间的,关于它的详情可参考付费合集;生存支持向量机,细节内容就不多说了,直接上代码:
set.seed(12358)
prediction_svm <- lrn("surv.svm", type = "regression", gamma.mu = 1e-3)$
train(tsk_rats, split$train)$predict(tsk_rats, split$test)
data.frame(pred = prediction_svm$response[1:3],
truth = prediction_svm$truth[1:3])
## pred truth
## 1 91.13665 91+
## 2 90.92749 104+
## 3 90.51497 104+
结果中pred
是预测的生存时间,truth
是真实的生存时间,可以看到预测的时间基本上是小于真实时间的(你的结果可能和我不一样)。但是每个真实时间都是删失的,所以也不好判定这个预测结果的好坏。因此预测生存时间基本上很少用,因为无法判断结果的好坏。
distr
与回归分析中常见的点预测(或者叫确定性预测)不同,生存分析更常见的是分布预测。mlr3proba
中的大多数模型都会默认进行分布预测,这种预测时通过distr
包实现的,并支持生存曲线的可视化。
下面我们提取cox模型的前3个分布预测结果(时间t=77):
prediction_cph$distr[1:3]$survival(77)
## [,1] [,2] [,3]
## 77 0.9947029 0.877308 0.8761194
结果表明前3个样本分别有92.89632%,99.48477%,99.48477%的概率在时间为77时存活。
lp
线性预测值在普通线性回归中是非常好理解的,就是其中的线性部分,但是在各种高阶生存分析中,模型方程基本都不是线性的,所以计算起来也没有那么简单。所以mlr3proba
中的线性预测值其实是风险分数(风险排名,下面要讲的crank)的近似值。
crank
crank
(ContinuousRANKing)是生存分析中非常常见的预测类型,这个通常也是生存分析中预测的风险分数。不同的模型对于这个风险分数的定义是不同的,有的时候crank越大发生终点事件的可能性越大,有的则刚好相反。为了避免混淆,mlr3proba
对这个crank进行了统一,mlr3proba
中的crank统一都是连续型的,而且这个值越大,发生终点事件的可能性越大(值越大风险越高)。
prediction_cph$crank[1:3]
## 1 2 3
## -2.5812496 0.6233550 0.6336594
这个结果说明第2个样本的风险最小(去掉负号比较)
这个结果只能说明谁的风险大谁的风险小,但是这些数字其实是没有任何实际意义的,所以比较不同样本的crank也是没有任何意义的。
目前对于生存分析,mlr3
支持以下评价指标:
# 目前支持24种评价指标
as.data.table(mlr_measures)[grepl("surv",mlr_measures$keys())][,c(1,5)]
## key predict_type
## 1: surv.brier distr
## 2: surv.calib_alpha distr
## 3: surv.calib_beta lp
## 4: surv.chambless_auc lp
## 5: surv.cindex crank
## 6: surv.dcalib distr
## 7: surv.graf distr
## 8: surv.hung_auc lp
## 9: surv.intlogloss distr
## 10: surv.logloss distr
## 11: surv.mae response
## 12: surv.mse response
## 13: surv.nagelk_r2 lp
## 14: surv.oquigley_r2 lp
## 15: surv.rcll distr
## 16: surv.rmse response
## 17: surv.schmid distr
## 18: surv.song_auc lp
## 19: surv.song_tnr lp
## 20: surv.song_tpr lp
## 21: surv.uno_auc lp
## 22: surv.uno_tnr lp
## 23: surv.uno_tpr lp
## 24: surv.xu_r2 lp
## key predict_type
大家最常见的指标,比如brierscore(surv.graf)、auc、cindex等都是支持的,不同的指标需要不同的预测结果类型(如上所示)。
目前对于指标的选择并没有金标准,mlr3
官方推荐对于右删失数据使用right-censored-logloss(msr("surv.rcll")
)评价dsitr
型的预测结果,使用一致性指数(msr("surv.cindex")
)评价模型的区分度,使用D-校准指数(msr("surv.dcalib")
)评价的模型校准度。
下面是使用3种指标评价我们的模型:
prediction_cph$score(msrs(c("surv.rcll", "surv.cindex", "surv.dcalib")))
## surv.rcll surv.cindex surv.dcalib
## 3.7365110 0.8191794 1.5182567
logloss越小越好,是评估模型拟合程度的,单看surv.rcll,这个模型还不错;cindex是越接近1越好,但看cindex,这个模型也是还挺不错的;单看dcalib感觉模型不太行。但是如果不和基线模型(通常是Kaplan-Meier法)比较的话很难说某个模型到底是好还是坏。
mlr3proba
中的预测结果其实是分为native
和composed
的,native是模型本身的预测结果,composed是经过转换的,比如crank这个结果就是经过转换的(因为不同模型的crank代表的意思是不一样的,经过转换之后的意义就一样了,即crank越大风险越高),mlr3proba
会自动进行这一步,具体细节我们就不说了。
native和composed这两种预测结果是可以互相转换的,这个是通过compositor
这类管道操作实现的,目前最常见的是以下两种:
pipeline_crankcompositor()
:把distr
转换为crank
pipeline_distrcompositor()
:把lp
转换为distr
但是在实际使用时,第1种操作其实mlr3
默认会自动为我们进行,无需手动实现(除非你想重新再转换一下)。第2种操作才是我们会经常用到的。
比如梯度提升机GBM模型只支持lp
和crank
这两种类型,我们需要得到distr
这种类型的预测结果,就可以通过distrcompositor
实现。
首先看下默认的GBM模型是没有distr
这种预测结果的(mlr3book官网用的例子是正则化COX回归,但是目前版本的正则化COX回归已经支持distr这种预测类型了):
tsk_rats <- tsk("rats")$select(c("litter", "rx"))
split <- partition(tsk_rats)
learner <- lrn("surv.gbm")
# 没有distr
learner$train(tsk_rats, split$train)$predict(tsk_rats, split$test)
## <PredictionSurv> for 99 observations:
## row_ids time status crank lp
## 9 104 FALSE -1.3179129 -1.3179129
## 11 104 FALSE -1.3179129 -1.3179129
## 12 102 FALSE -1.3179129 -1.3179129
## ---
## 249 66 TRUE -0.1916111 -0.1916111
## 253 92 TRUE 0.3756322 0.3756322
## 297 79 TRUE -0.7755802 -0.7755802
下面我们选择GBM模型,然后指定基线评估方法为Kaplan-Meier法,并且假设我们的预测分布是比例风险的形式,使用distrcompositor
这个管道操作,即可让GBM模型支持distr
类型的输出:
graph_learner <- as_learner(ppl(
"distrcompositor",
learner = learner,
estimator = "kaplan",
form = "ph"
))
# 现在有distr了
graph_learner$train(tsk_rats, split$train)$predict(tsk_rats, split$test)
## <PredictionSurv> for 99 observations:
## row_ids time status crank lp distr
## 9 104 FALSE -1.0502036 -1.0502036 <list[1]>
## 11 104 FALSE -1.0502036 -1.0502036 <list[1]>
## 12 102 FALSE -1.0502036 -1.0502036 <list[1]>
## ---
## 249 66 TRUE -0.2894398 -0.2894398 <list[1]>
## 253 92 TRUE 0.2619725 0.2619725 <list[1]>
## 297 79 TRUE -0.7371226 -0.7371226 <list[1]>
是不是很牛逼呢?
各种转换都可以,这样就方便不同模型的比较了,不存在你有我没有的情况。
下面是一个多模型比较的例子,由于不同模型的预测类型可以互相转换,因此我们可以很方便的进行多模型的比较。
我们选择GBM模型、Kaplan-Meier法、cox模型进行比较,使用3折交叉验证:
tsk_grace <- tsk("grace")
tsk_grace$filter(sample(tsk_grace$nrow, 500))
msr_txt <- c("surv.rcll", "surv.cindex", "surv.dcalib")
measures <- msrs(msr_txt)
# 图学习器
graph_learner <- as_learner(ppl(
"distrcompositor",
learner = lrn("surv.gbm"),
estimator = "kaplan",
form = "ph"
))
graph_learner$id <- "gbm"
learners <- c(lrns(c("surv.coxph", "surv.kaplan")), graph_learner)
set.seed(1258)
bmr <- benchmark(benchmark_grid(tsk_grace, learners,
rsmp("cv", folds = 3)))
## INFO [13:55:07.774] [mlr3] Running benchmark with 9 resampling iterations
## INFO [13:55:07.818] [mlr3] Applying learner 'surv.coxph' on task 'grace' (iter 1/3)
## INFO [13:55:07.837] [mlr3] Applying learner 'surv.coxph' on task 'grace' (iter 2/3)
## INFO [13:55:07.851] [mlr3] Applying learner 'surv.coxph' on task 'grace' (iter 3/3)
## INFO [13:55:07.864] [mlr3] Applying learner 'surv.kaplan' on task 'grace' (iter 1/3)
## INFO [13:55:07.877] [mlr3] Applying learner 'surv.kaplan' on task 'grace' (iter 2/3)
## INFO [13:55:08.131] [mlr3] Applying learner 'surv.kaplan' on task 'grace' (iter 3/3)
## INFO [13:55:08.145] [mlr3] Applying learner 'gbm' on task 'grace' (iter 1/3)
## INFO [13:55:08.212] [mlr3] Applying learner 'gbm' on task 'grace' (iter 2/3)
## INFO [13:55:08.278] [mlr3] Applying learner 'gbm' on task 'grace' (iter 3/3)
## INFO [13:55:08.348] [mlr3] Finished benchmark
bmr$aggregate(measures)[, c("learner_id", ..msr_txt)]
## learner_id surv.rcll surv.cindex surv.dcalib
## 1: surv.coxph 5.075566 0.8409297 7.071879
## 2: surv.kaplan 5.304874 0.5000000 4.431708
## 3: gbm 10.263326 0.8560972 10.000000
结果表明,gbm模型的区分度最好(surv.rcll最高),Kaplan-Meier法的校准度最好(surv.dcalib最低),coxph和gbm的模型准确性(模型拟合程度)差不多(cindex差不多)。
遥遥领先!强烈推荐,大家快去学!
公众号后台回复mlr3即可获取相关推文合集。
联系我们,关注我们
免费QQ交流群1:613637742 免费QQ交流群2:608720452 公众号消息界面关于作者获取联系方式 知乎、CSDN、简书同名账号 哔哩哔哩:阿越就是我
https://mp.weixin.qq.com/s/Js7_axZ3DIkHupoYpACZOA