ixxmu / mp_duty

抓取网络文章到github issues保存
https://archives.duty-machine.now.sh/
103 stars 30 forks source link

mlr3图学习器调参的多种方法 #4720

Closed ixxmu closed 5 months ago

ixxmu commented 5 months ago

https://mp.weixin.qq.com/s/M5SjhiEOmuh8MQ-YrkZ-3A

ixxmu commented 5 months ago

mlr3图学习器调参的多种方法 by 医学和生信笔记




mlr3的使用非常灵活,做一件事情有多种实现方法,虽然每种方法都可以实现目的,但是在使用过程中还是有一些不同之处。

对于刚接触的mlr3的朋友可以先看mlr3入门合集:mlr3教程汇总(有些过时了,可以看看基础部分,实践部分请看2023年及以后的推文,2023年之前的代码会报错)

本次主要给大家演示同时有数据预处理步骤和调参步骤时,如何使用多种方法实现,其实就是演示下如何整合mlr3pipelines的使用,之前介绍过如何使用mlr3pipelines进行数据预处理和特征工程,但是并没有介绍如何和后面的模型拟合等过程连接起来。

在此之前,已经给大家介绍了:

这里要强调一下,mlr3pipelines的功能非常强大,不但是可以做数据预处理和特征工程,结合mlr3独特的图形流,可以实现非常复杂的操作,这篇主要是介绍如何整合它的数据预处理和超参数调优。这个过程和tidymodels中的recipes的操作非常像,大家可以对比学习。

加载R包

rm(list = ls())
library(mlr3verse)
library(mlr3pipelines)
library(magrittr)

一定要熟悉mlr3图形流,如果你还不了解,下面的内容很难理解,可以参考:

方法1

首先选择数据预处理步骤,并使用mlr3独有的管道符%>>%连接预处理步骤和学习器(算法),然后把它们变成一个图学习器(GraphLearner).

# 注意不同的管道符,不要用错了!
glr <- po("encode") %>>%
  po("imputeoor") %>>%
  po("learner",lrn("classif.ranger",predict_type="prob")) %>% 
  as_learner() # 转变为图学习器,这一步很重要!

# 换个简短的名字
glr$id <- "randomForest"
glr
## <GraphLearner:randomForest>
## * Model: -
## * Parameters: encode.method=one-hot, imputeoor.min=TRUE,
##   imputeoor.offset=1, imputeoor.multiplier=1,
##   classif.ranger.num.threads=1
## * Packages: mlr3, mlr3pipelines, stats, mlr3learners, ranger
## * Predict Types:  response, [prob]
## * Feature Types: logical, integer, numeric, character, factor, ordered,
##   POSIXct
## * Properties: featureless, hotstart_backward, hotstart_forward,
##   importance, loglik, missings, multiclass, oob_error,
##   selected_features, twoclass, weights

这个方法和tidymodels中的工作流的概念如出一辙,简直是一模一样,二者结合起来非常好理解而且容易记忆!工作流的概念可参考:tidymodels工作流:workflow

查看可以调节的超参数:

glr$param_set
## <ParamSetCollection>
##                                              id    class lower upper nlevels
##  1:                        classif.ranger.alpha ParamDbl  -Inf   Inf     Inf
##  2:       classif.ranger.always.split.variables ParamUty    NA    NA     Inf
##  3:                classif.ranger.class.weights ParamUty    NA    NA     Inf
##  4:                      classif.ranger.holdout ParamLgl    NA    NA       2
##  5:                   classif.ranger.importance ParamFct    NA    NA       4
##  6:                   classif.ranger.keep.inbag ParamLgl    NA    NA       2
##  7:                    classif.ranger.max.depth ParamInt     0   Inf     Inf
##  8:                   classif.ranger.min.bucket ParamInt     1   Inf     Inf
##  9:                classif.ranger.min.node.size ParamInt     1   Inf     Inf
## 10:                      classif.ranger.minprop ParamDbl  -Inf   Inf     Inf
## 11:                         classif.ranger.mtry ParamInt     1   Inf     Inf
## 12:                   classif.ranger.mtry.ratio ParamDbl     0     1     Inf
## 13:                   classif.ranger.node.stats ParamLgl    NA    NA       2
## 14:            classif.ranger.num.random.splits ParamInt     1   Inf     Inf
## 15:                  classif.ranger.num.threads ParamInt     1   Inf     Inf
## 16:                    classif.ranger.num.trees ParamInt     1   Inf     Inf
## 17:                    classif.ranger.oob.error ParamLgl    NA    NA       2
## 18:        classif.ranger.regularization.factor ParamUty    NA    NA     Inf
## 19:      classif.ranger.regularization.usedepth ParamLgl    NA    NA       2
## 20:                      classif.ranger.replace ParamLgl    NA    NA       2
## 21:    classif.ranger.respect.unordered.factors ParamFct    NA    NA       3
## 22:              classif.ranger.sample.fraction ParamDbl     0     1     Inf
## 23:                  classif.ranger.save.memory ParamLgl    NA    NA       2
## 24: classif.ranger.scale.permutation.importance ParamLgl    NA    NA       2
## 25:                    classif.ranger.se.method ParamFct    NA    NA       2
## 26:                         classif.ranger.seed ParamInt  -Inf   Inf     Inf
## 27:         classif.ranger.split.select.weights ParamUty    NA    NA     Inf
## 28:                    classif.ranger.splitrule ParamFct    NA    NA       3
## 29:                      classif.ranger.verbose ParamLgl    NA    NA       2
## 30:                 classif.ranger.write.forest ParamLgl    NA    NA       2
## 31:                       encode.affect_columns ParamUty    NA    NA     Inf
## 32:                               encode.method ParamFct    NA    NA       5
## 33:                    imputeoor.affect_columns ParamUty    NA    NA     Inf
## 34:                               imputeoor.min ParamLgl    NA    NA       2
## 35:                        imputeoor.multiplier ParamDbl     0   Inf     Inf
## 36:                            imputeoor.offset ParamDbl     0   Inf     Inf
##                                              id    class lower upper nlevels
##            default                   parents   value
##  1:            0.5                                  
##  2: <NoDefault[3]>                                  
##  3:                                                 
##  4:          FALSE                                  
##  5: <NoDefault[3]>                                  
##  6:          FALSE                                  
##  7:                                                 
##  8:              1                                  
##  9:                                                 
## 10:            0.1                                  
## 11: <NoDefault[3]>                                  
## 12: <NoDefault[3]>                                  
## 13:          FALSE                                  
## 14:              1  classif.ranger.splitrule        
## 15:              1                                 1
## 16:            500                                  
## 17:           TRUE                                  
## 18:              1                                  
## 19:          FALSE                                  
## 20:           TRUE                                  
## 21:         ignore                                  
## 22: <NoDefault[3]>                                  
## 23:          FALSE                                  
## 24:          FALSE classif.ranger.importance        
## 25:        infjack                                  
## 26:                                                 
## 27:                                                 
## 28:           gini                                  
## 29:           TRUE                                  
## 30:           TRUE                                  
## 31:  <Selector[1]>                                  
## 32: <NoDefault[3]>                           one-hot
## 33: <NoDefault[3]>                                  
## 34: <NoDefault[3]>                              TRUE
## 35: <NoDefault[3]>                                 1
## 36: <NoDefault[3]>                                 1
##            default                   parents   value

当把预处理步骤和学习器连接在一起时,因为有些预处理步骤也是有超参数的,所以此时mlr3会在超参数前面加上不同的前缀,方便区分是学习器的超参数还是预处理步骤的超参数.缺点就是参数的名字真的好长......

选择超参数网格:

search_space <- ps(
  classif.ranger.max.depth = p_int(lower = 3,upper = 10),
  classif.ranger.min.node.size = p_int(10,300)
)
search_space
## <ParamSet>
##                              id    class lower upper nlevels        default
## 1:     classif.ranger.max.depth ParamInt     3    10       8 <NoDefault[3]>
## 2: classif.ranger.min.node.size ParamInt    10   300     291 <NoDefault[3]>
##    value
## 1:      
## 2:

进行计算、调参:

# 加速
library(future)
plan("multisession",workers=12)

# 减少屏幕输出
lgr::get_logger("mlr3")$set_threshold("warn")
lgr::get_logger("bbotk")$set_threshold("warn")

set.seed(123)
glr_res <- tune(
  tuner = tnr("grid_search", resolution = 5),
  task = tsk("pima"),
  learner = glr,
  resampling = rsmp("holdout"),
  measures = msr("classif.acc"),
  terminator = trm("none"),
  search_space = search_space,
  store_models = T
)

# 查看最好的超参数
glr_res$result_learner_param_vals
## $encode.method
## [1] "one-hot"
## 
## $imputeoor.min
## [1] TRUE
## 
## $imputeoor.offset
## [1] 1
## 
## $imputeoor.multiplier
## [1] 1
## 
## $classif.ranger.num.threads
## [1] 1
## 
## $classif.ranger.max.depth
## [1] 6
## 
## $classif.ranger.min.node.size
## [1] 10
# 查看最好的结果和对应的超参数
as.data.table(glr_res$result)[,c(1,2,5)]
##    classif.ranger.max.depth classif.ranger.min.node.size classif.acc
## 1:                        6                           10        0.75
# 查看每一组超参数的结果
as.data.table(glr_res$archive$data)[,1:3]
##     classif.ranger.max.depth classif.ranger.min.node.size classif.acc
##  1:                        3                          155   0.7226562
##  2:                        3                           82   0.7187500
##  3:                       10                           10   0.7343750
##  4:                        4                          300   0.7265625
##  5:                       10                          155   0.7265625
##  6:                       10                           82   0.7304688
##  7:                        3                          300   0.7226562
##  8:                        3                           10   0.7382812
##  9:                        3                          228   0.7187500
## 10:                        8                           10   0.7343750
## 11:                        6                          228   0.7187500
## 12:                        6                           82   0.7460938
## 13:                        4                           10   0.7421875
## 14:                        4                           82   0.7382812
## 15:                        8                           82   0.7460938
## 16:                        4                          155   0.7187500
## 17:                        8                          228   0.7109375
## 18:                        6                          300   0.7265625
## 19:                        8                          155   0.7304688
## 20:                        6                           10   0.7500000
## 21:                        4                          228   0.7109375
## 22:                        6                          155   0.7226562
## 23:                        8                          300   0.7265625
## 24:                       10                          300   0.7304688
## 25:                       10                          228   0.7265625
##     classif.ranger.max.depth classif.ranger.min.node.size classif.acc

可视化调参结果:

autoplot(glr_res,type = "performance")

把训练好的超参数应用于算法,重新训练、预测:

glr$param_set$values = glr_res$result_learner_param_vals
glr$train(tsk("pima"))$predict(tsk("pima"))
## <PredictionClassif> for 768 observations:
##     row_ids truth response   prob.pos  prob.neg
##           1   pos      pos 0.66378828 0.3362117
##           2   neg      neg 0.13229628 0.8677037
##           3   pos      pos 0.64295310 0.3570469
## ---                                            
##         766   neg      neg 0.13735664 0.8626434
##         767   pos      pos 0.51792572 0.4820743
##         768   neg      neg 0.07776001 0.9222400

方法2

建立auto_tuner,很强大,以后会经常用!

# 建立自动调参器
glr_at <- po("encode") %>>% #注意不同的管道符!
  po("imputeoor") %>>%
  po("learner",lrn("classif.ranger",predict_type="prob")) %>% 
  as_learner() %>% #不要忘记这一步哦~
  auto_tuner(
    tuner = tnr("grid_search", resolution = 5),
    learner = .,
    resampling = rsmp("cv",folds = 3),
    measure = msr("classif.acc"),
    terminator = trm("none"),
    search_space = search_space,
    store_models = T
  )

# 可以改个短一点的名字
glr_at$id <- "glr_autoTuner"

# 执行调参
set.seed(123)
at_res <- glr_at$train(tsk("pima"))

# 查看最好的超参数
at_res$tuning_result$learner_param_vals[[1]]
## $encode.method
## [1] "one-hot"
## 
## $imputeoor.min
## [1] TRUE
## 
## $imputeoor.offset
## [1] 1
## 
## $imputeoor.multiplier
## [1] 1
## 
## $classif.ranger.num.threads
## [1] 1
## 
## $classif.ranger.max.depth
## [1] 8
## 
## $classif.ranger.min.node.size
## [1] 10

自动调参器虽然牛逼,但是这样就不能使用autoplot直接可视化结果了,当然你还是可以自己提取数据画的.

重新建立graph_learner,然后把调好的超参数应用于模型,再次训练、建模:

glr <- po("encode") %>>%
  po("imputeoor") %>>%
  po("learner",lrn("classif.ranger",predict_type="prob")) %>% 
  as_learner()

glr$param_set$values <- at_res$tuning_result$learner_param_vals[[1]]
glr$train(tsk("pima"))$predict(tsk("pima"))
## <PredictionClassif> for 768 observations:
##     row_ids truth response   prob.pos  prob.neg
##           1   pos      pos 0.74703416 0.2529658
##           2   neg      neg 0.10255277 0.8974472
##           3   pos      pos 0.72220748 0.2777925
## ---                                            
##         766   neg      neg 0.10274966 0.8972503
##         767   pos      pos 0.61624926 0.3837507
##         768   neg      neg 0.05230132 0.9476987

以上就是在同时有数据预处理和超参数调优的情况下,进行mlr3调参的基本操作,可以看到再次用到了auto_tuner,真的很好用,以后还会经常用的,大家一定要学会它的用法。

但是变为图学习器之后使用tune()的方式更加符合R语言中的做法,与tidymodels中的调优方式也很相似,方便记忆,而且还有非常多的可视化选项,我在日常使用中还是更喜欢这种方式。

欢迎大家评论区留言或者加入QQ群交流。后台回复mlr3即可获取mlr3推文合集.