ixxmu / mp_duty

抓取网络文章到github issues保存
https://archives.duty-machine.now.sh/
103 stars 30 forks source link

R语言lightGBM超参数调优 #4445

Closed ixxmu closed 7 months ago

ixxmu commented 7 months ago

https://mp.weixin.qq.com/s/jICU30gQlWFbtXR7tHjClA

ixxmu commented 7 months ago

R语言lightGBM超参数调优 by 医学和生信笔记

关注公众号,发送R语言python,可获取资料

💡专注R语言在🩺生物医学中的使用


设为“星标”,精彩不错过


上一篇推文介绍了R语言lightgbm快速上手,主要是介绍原生的lightgbm包的使用方法。

LightGBMxgboost、Catboost并称为GBDT三大神器,加上本篇介绍之后就剩Catboost还没介绍过了,等Catboost可用之后我就立马介绍它。这样常见的一些方法我就全都介绍过了,后面就是模型解释的内容了。

在R中对lightgbm进行超参数调优可以通过tidymodels或者mlr3实现,这里简单说下如果通过tidymodels实现。其实这和其他算法并没有什么不同,只不过是换个“引擎”而已。

要想得到更好的结果,关键还是看你对自己数据的理解以及对算法本身的理解。不管是tidymodels还是mlr3都只是一种实现工具而已。

tidymodels中的超参数调优的方法很少,目前只支持网格搜索、贝叶斯优化、模拟退火3种方法。其中模拟退火需要额外安装finetune包才可实现。

注意通过tidymodels实现lightgbm需要额外安装bonsai包。

加载R包和数据

加载R包和数据,这是一个3分类的数据,其中species是结果变量,因子型,其余变量是预测变量:

library(tidymodels)
library(bonsai)
library(modeldata)
library(ggplot2)

data("penguins")
str(penguins)
## tibble [344 × 7] (S3: tbl_df/tbl/data.frame)
##  $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ island           : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
##  $ bill_length_mm   : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
##  $ bill_depth_mm    : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
##  $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
##  $ body_mass_g      : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
##  $ sex              : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...

之前我们也介绍过,tidymodels支持的模型真的很有限,比如常见的lightGBM它就不支持。。。对于提升树模型,它目前只支持以下5个引擎:

# 加载bonsai之后是7个
show_engines('boost_tree')
## # A tibble: 7 × 2
##   engine   mode          
##   <chr>    <chr>         
## 1 xgboost  classification
## 2 xgboost  regression    
## 3 C5.0     classification
## 4 spark    classification
## 5 spark    regression    
## 6 lightgbm regression    
## 7 lightgbm classification

除此之外,catboost目前还是不支持的,之前的treesnip同时支持catboostlightGBM,但是现在treesnip已经停止维护了,并被bonsai取代,但是bonsai目前并不支持catboost(因为catboost目前并不在CRAN)。官方说大概3个月后支持catboost(现在是20231115)。

模型设定

在这一步我们要设置好哪些超参数需要进行调优,这些超参数和xgboost太像了,所以没详细介绍,大家可以参考xgboost超参数调优

bt_light <- boost_tree(trees = 1000, mtry = tune(), tree_depth = tune(),  
                       learn_rate = tune(), min_n = tune(), 
                       loss_reduction = tune()) %>%
  set_engine("lightgbm",objective = "multiclass",num_class=3) %>%
  set_mode("classification")

bt_light
## Boosted Tree Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##   tree_depth = tune()
##   learn_rate = tune()
##   loss_reduction = tune()
## 
## Engine-Specific Arguments:
##   objective = multiclass
##   num_class = 3
## 
## Computational engine: lightgbm

建立工作流

接下来是建立工作流,即使是没有任何预处理步骤,也推荐你使用工作流,目前tidymodels在强推工作流,在很多时候确实更加方便,语法也更加统一、好记忆。

bt_wf <- workflow() %>% 
  add_formula(species ~ .) %>% 
  add_model(bt_light)
bt_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## species ~ .
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##   tree_depth = tune()
##   learn_rate = tune()
##   loss_reduction = tune()
## 
## Engine-Specific Arguments:
##   objective = multiclass
##   num_class = 3
## 
## Computational engine: lightgbm

网格设定

选择一个简单的拉丁超立方网格,这个比传统的网格更有机会覆盖全部的超参数空间:

set.seed(123)
tree_grid <- grid_max_entropy(mtry(range = c(2L6L)),
                          tree_depth(),
                          learn_rate(),
                          min_n(),
                          loss_reduction(),
                          size = 15 # 只产生15个模型配置
                          )
#tree_grid

数据划分

数据划分选择3折交叉验证,我这里并没有划分训练集测试集,有需要的自己分一下即可:

set.seed(123)
bt_folds <- vfold_cv(penguins, v=3)
bt_folds
## #  3-fold cross-validation 
## # A tibble: 3 × 2
##   splits            id   
##   <list>            <chr>
## 1 <split [229/115]> Fold1
## 2 <split [229/115]> Fold2
## 3 <split [230/114]> Fold3

开始调参

set.seed(123)
bt_tune <- tune_grid(bt_wf,
                     bt_folds,
                     grid = tree_grid,
                     control = control_grid(save_pred = T,verbose = F)
                     )
#save(bt_tune,file = "../000机器学习/bt_tune.rdata")

结果探索

查看训练结果:

load(file = "../000机器学习/bt_tune.rdata")

bt_tune %>% collect_metrics()
## # A tibble: 30 × 11
##     mtry min_n tree_depth learn_rate loss_reduction .metric  .estimator  mean
##    <int> <int>      <int>      <dbl>          <dbl> <chr>    <chr>      <dbl>
##  1     4    31          7   1.38e- 8       4.81e- 1 accuracy multiclass 0.442
##  2     4    31          7   1.38e- 8       4.81e- 1 roc_auc  hand_till  0.989
##  3     3    12          1   4.19e-10       1.26e- 5 accuracy multiclass 0.442
##  4     3    12          1   4.19e-10       1.26e- 5 roc_auc  hand_till  0.989
##  5     2    11          5   7.76e- 6       2.64e- 4 accuracy multiclass 0.442
##  6     2    11          5   7.76e- 6       2.64e- 4 roc_auc  hand_till  0.992
##  7     3    10         13   3.60e- 3       2.62e- 3 accuracy multiclass 0.983
##  8     3    10         13   3.60e- 3       2.62e- 3 roc_auc  hand_till  0.997
##  9     5     4          5   1.11e- 7       4.83e-10 accuracy multiclass 0.442
## 10     5     4          5   1.11e- 7       4.83e-10 roc_auc  hand_till  0.988
## # ℹ 20 more rows
## # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

结果可视化:

autoplot(bt_tune)
plot of chunk unnamed-chunk-9

查看预测结果:

bt_tune %>% 
  collect_predictions()
## # A tibble: 5,160 × 13
##    id    .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row  mtry min_n tree_depth
##    <chr>        <dbl>           <dbl>        <dbl> <int> <int> <int>      <int>
##  1 Fold1        0.428           0.210        0.362     5     4    31          7
##  2 Fold1        0.428           0.210        0.362     7     4    31          7
##  3 Fold1        0.428           0.210        0.362     8     4    31          7
##  4 Fold1        0.428           0.210        0.362     9     4    31          7
##  5 Fold1        0.428           0.210        0.362    12     4    31          7
##  6 Fold1        0.428           0.210        0.362    13     4    31          7
##  7 Fold1        0.428           0.210        0.362    17     4    31          7
##  8 Fold1        0.428           0.210        0.362    19     4    31          7
##  9 Fold1        0.428           0.210        0.362    22     4    31          7
## 10 Fold1        0.428           0.210        0.362    25     4    31          7
## # ℹ 5,150 more rows
## # ℹ 5 more variables: learn_rate <dbl>, loss_reduction <dbl>,
## #   .pred_class <fct>, species <fct>, .config <chr>

画个3分类的ROC曲线(不会的可参考历史推文R语言多分类ROC曲线绘制):

bt_tune %>% 
  collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% 
  ggplot(aes(1 - specificity, sensitivity, color = .level)) +
  geom_abline(lty = 2, color = "gray80", linewidth = 1.5) +
  geom_path(alpha = 0.8, linewidth = 1) +
  coord_equal() +
  labs(color = NULL)
plot of chunk unnamed-chunk-11

也可以使用autoplot

bt_tune %>% 
  collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% 
  autoplot()
plot of chunk unnamed-chunk-12

绘制混淆矩阵:

bt_tune %>% 
  collect_predictions() %>%
  conf_mat(species, .pred_class) %>%
  autoplot()
plot of chunk unnamed-chunk-13

查看表现最好的超参数:

bt_best <- bt_tune %>% 
  select_best("roc_auc")

bt_best
## # A tibble: 1 × 6
##    mtry min_n tree_depth learn_rate loss_reduction .config              
##   <int> <int>      <int>      <dbl>          <dbl> <chr>                
## 1     3    10         13    0.00360        0.00262 Preprocessor1_Model04

重新拟合

选择最好的参数,重新拟合模型

bt_fit <- bt_wf %>% 
  finalize_workflow(bt_best) %>% 
  fit(penguins)
bt_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## species ~ .
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## LightGBM Model (1000 trees)
## Objective: multiclass (3 classes)
## Fitted to dataset with 6 columns

查看变量重要性:

library(vip)
bt_fit %>% 
  extract_fit_parsnip() %>%
  vip()
plot of chunk unnamed-chunk-16

预测新的数据:

new_data <- head(penguins)

predict(bt_fit, new_data = new_data)
## # A tibble: 6 × 1
##   .pred_class
##   <fct>      
## 1 Adelie     
## 2 Adelie     
## 3 Adelie     
## 4 Adelie     
## 5 Adelie     
## 6 Adelie

一套打完,结束!

tidymodels各种用法还不熟悉的可在后台会回复tidymodels获取相关合集。回复lightgbm可获取相关推文合集。



联系我们,关注我们

  1. 免费QQ交流群1:613637742
  2. 免费QQ交流群2:608720452
  3. 公众号消息界面关于作者获取联系方式
  4. 知乎、CSDN、简书同名账号
  5. 哔哩哔哩:阿越就是我

ixxmu commented 7 months ago

mlr干到一切,tidymodels 支持格式太少,尽量避免从它开始