ixxmu / mp_duty

抓取网络文章到github issues保存
https://archives.duty-machine.now.sh/
110 stars 30 forks source link

tidymodels实现lasso回归及超参数调优 #4777

Closed ixxmu closed 6 months ago

ixxmu commented 6 months ago

https://mp.weixin.qq.com/s/UdGPhQ3tvkvpUHuPmLgOVA

ixxmu commented 6 months ago

tidymodels实现lasso回归及超参数调优 by 医学和生信笔记




之前已经给大家介绍过使用glmnet实现lasso回归了,详情可见:使用glmnet做lasso回归

今天给大家演示下如何使用tidymodels实现lasso回归。

整理数据的过程略,大家可以直接去粉丝QQ群文件下载整理好的数据(别问我怎么加群,加了群也别问我数据在哪里!)。

今天使用的数据是一个回归数据,结果变量是imdb_rating评分,是数值型的,其余都是预测变量,数据一共有136行,32列。

加载R包和数据

加载数据,并分为训练集和测试集,比例为3/4:

library(tidyverse)
library(tidymodels)
load(file = "office.rdata")

office_split <- initial_split(office, strata = season)
office_train <- training(office_split)
office_test <- testing(office_split)

预处理

首先处理下episode_name这个变量,它既不是预测变量,也不是结果变量,但是我们需要它留在数据集中,所以为了不影响建模过程,给它分配一个新的角色:ID

去掉零方差变量,再进行中心化和标准化:

office_rec <- recipe(imdb_rating ~ ., data = office_train) %>%
  update_role(episode_name, new_role = "ID") %>%
  step_zv(all_numeric(), -all_outcomes()) %>%
  step_normalize(all_numeric(), -all_outcomes())

office_prep <- office_rec %>%
  prep(strings_as_factors = FALSE)

建模

我们通过tidymodels实现lasso回归,其实还是要借助glmnet包的,因为tidymodels本身并不能实现任何算法,它只是提供一个统一的接口而已,这一点我们在之前也详细介绍过:tidymodels之parsnip的强大之处

glmnet虽然强大,但是它的一些参数非常的不人性化,tidymodels根据每个参数的意义对这些参数进行了重新命名。

tidymodels中的mixture就是glmnet中的alpha,所以mixture=1就是使用lasso回归。

如果要学习glmnet的用法,我之前也详细介绍过:使用glmnet做lasso回归

我们先自己指定一个惩罚值(这里是0.1)看看效果如何:

lasso_spec <- linear_reg(penalty = 0.1, mixture = 1) %>%
  set_engine("glmnet")

wf <- workflow() %>%
  add_recipe(office_rec)

lasso_fit <- wf %>%
  add_model(lasso_spec) %>%
  fit(data = office_train)

这样模型就拟合好了,如何查看此时的变量系数呢?既然使用了tidymodels,就要了解它的游戏规则,当然是提取出来了。不了解tidymodels的后台回复tidymodels即可获取合集链接。

lasso_fit %>%
  pull_workflow_fit() %>%
  tidy()

## # A tibble: 31 × 3
##    term        estimate penalty
##    <chr>          <dbl>   <dbl>
##  1 (Intercept)   8.36       0.1
##  2 season        0          0.1
##  3 episode       0          0.1
##  4 andy          0          0.1
##  5 angela        0          0.1
##  6 darryl        0          0.1
##  7 dwight        0          0.1
##  8 jim           0.0283     0.1
##  9 kelly         0          0.1
## 10 kevin         0          0.1
## # ℹ 21 more rows

一目了然。

但是用过glmnet的朋友都知道,它会自动使用100个惩罚值建模,然后你可以从中选择最好的。这个过程有点类似于超参数调优哦~

下面我们在tidymodels中实现对惩罚值的调优,也就是找到最适合的惩罚值(即正则化大小)。

超参数调优

下面是设定超参数范围的过程,我之前也是详细介绍过了,基本上我的推文都有一个大致的顺序,一般是从简单到复杂,从一般到特殊,前后具有一定的连贯性,所以如果不了解tidymodels并且也没看过之前的推文的小伙伴可能看不太懂,因为解释很少,这时大家可以先去翻看之前的基础介绍,比如这里的超参数调优之前就介绍过了:

公众号后台回复tidymodels即可获取推文合集,一些知识点可翻阅既往推文。

set.seed(1234)
office_boot <- bootstraps(office_train, strata = season)

tune_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

# 简单的规则网格
#lambda_grid <- grid_regular(penalty(), levels = 100)

# 使用空间填充设计:拉丁超立方
set.seed(1234)
lambda_grid <- grid_latin_hypercube(penalty(), size = 100)

然后就是调优过程了,我们这个数据集非常小,所以也不用并行化:

set.seed(2020)
lasso_grid <- tune_grid(
  wf %>% add_model(tune_spec),
  resamples = office_boot,
  grid = lambda_grid,
  control = control_grid(verbose = FALSE, save_pred = T)
)

查看模型表现:

lasso_grid %>%
  collect_metrics()
## # A tibble: 200 × 7
##     penalty .metric .estimator  mean     n std_err .config               
##       <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
##  1 1.13e-10 rmse    standard   0.570    25  0.0115 Preprocessor1_Model001
##  2 1.13e-10 rsq     standard   0.110    25  0.0129 Preprocessor1_Model001
##  3 1.41e-10 rmse    standard   0.570    25  0.0115 Preprocessor1_Model002
##  4 1.41e-10 rsq     standard   0.110    25  0.0129 Preprocessor1_Model002
##  5 1.94e-10 rmse    standard   0.570    25  0.0115 Preprocessor1_Model003
##  6 1.94e-10 rsq     standard   0.110    25  0.0129 Preprocessor1_Model003
##  7 2.16e-10 rmse    standard   0.570    25  0.0115 Preprocessor1_Model004
##  8 2.16e-10 rsq     standard   0.110    25  0.0129 Preprocessor1_Model004
##  9 3.06e-10 rmse    standard   0.570    25  0.0115 Preprocessor1_Model005
## 10 3.06e-10 rsq     standard   0.110    25  0.0129 Preprocessor1_Model005
## # ℹ 190 more rows

可视化

lasso回归肯定是需要把结果画出来的,在tidymodels中我们需要提取数据再画图,需要一定的R语言基础。但是如果你了解tidymodelsggplot2也是非常简单。

lasso_grid %>%
  collect_metrics() %>%
  ggplot(aes(penalty, mean, color = .metric)) +
  geom_errorbar(aes(
    ymin = mean - std_err,
    ymax = mean + std_err
  ),
  alpha = 0.5,width=0.2,linewidth=1.2
  ) +
  geom_line(size = 1.5) +
  facet_wrap(~.metric, scales = "free", nrow = 2) +
  scale_x_log10() +
  theme(legend.position = "none")
plot of chunk unnamed-chunk-9

这张图同时展示了rmsersp两个指标随惩罚值的变化,具体的解释这里就不多介绍了,可以参考之前的推文:使用glmnet做lasso回归

确定最终模型

如何确定最好的惩罚值呢?tidymodels也为我们提供了几个好用的函数:

查看最好的几个结果:

lasso_grid %>% show_best("rmse")
## # A tibble: 5 × 7
##   penalty .metric .estimator  mean     n std_err .config               
##     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1  0.0497 rmse    standard   0.475    25  0.0117 Preprocessor1_Model087
## 2  0.0587 rmse    standard   0.476    25  0.0120 Preprocessor1_Model088
## 3  0.0687 rmse    standard   0.476    25  0.0120 Preprocessor1_Model089
## 4  0.0322 rmse    standard   0.479    25  0.0110 Preprocessor1_Model086
## 5  0.0998 rmse    standard   0.480    25  0.0115 Preprocessor1_Model090

lasso_grid %>% show_best("rsq")
## # A tibble: 5 × 7
##   penalty .metric .estimator  mean     n std_err .config               
##     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
## 1  0.0198 rsq     standard   0.123    25  0.0147 Preprocessor1_Model083
## 2  0.0206 rsq     standard   0.123    25  0.0148 Preprocessor1_Model084
## 3  0.0132 rsq     standard   0.122    25  0.0144 Preprocessor1_Model082
## 4  0.0278 rsq     standard   0.121    25  0.0152 Preprocessor1_Model085
## 5  0.0116 rsq     standard   0.121    25  0.0141 Preprocessor1_Model081

选择最好的结果,有点类似于选择glmnet中的lambda.1se/lambda.min

lasso_grid %>% select_best("rmse")
## # A tibble: 1 × 2
##   penalty .config               
##     <dbl> <chr>                 
## 1  0.0497 Preprocessor1_Model087

我们选择rmse最小的惩罚值,重新建立最终的模型:

lowest_rmse <- lasso_grid %>%
  select_best("rmse", maximize = FALSE)

final_lasso <- finalize_workflow(
  wf %>% add_model(tune_spec),
  lowest_rmse
)

测试集

把最终的模型重新在训练集进行拟合并查看在测试集的表现:

final_lasso %>% last_fit(office_split) %>%
  collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       0.470 Preprocessor1_Model1
## 2 rsq     standard       0.336 Preprocessor1_Model1

搞定!

大家觉得和glmnet比起来更喜欢哪种方式呢?

我是觉得各有千秋吧,如果只是单纯的想做一下lasso,我觉得还是直接用glmnet更简单一些,如果是需要做多个模型,那肯定还是tidymodels更合适一些。