tidymodels / agua

Create and evaluate models using 'tidymodels' and 'h2o'
https://agua.tidymodels.org
Other
21 stars 2 forks source link

Add auto_ml h2o engine #25

Closed qiushiyan closed 2 years ago

qiushiyan commented 2 years ago

Updated for the new API.

library(agua)
#> Loading required package: parsnip
library(ggplot2)
library(dplyr, quietly = TRUE, warn.conflicts = FALSE)
h2o_start()

# Model specification: H2O AutoML in regression mode. The engine-specific
# arguments (runtime limit in seconds, data saving, retention of
# cross-validation predictions, and seed) are supplied through set_engine().
mod <- set_mode(
  set_engine(
    auto_ml(),
    "h2o",
    max_runtime_secs = 10,
    save_data = TRUE,
    keep_cross_validation_predictions = TRUE,
    seed = 1
  ),
  "regression"
)

# Fit the specification with mpg as the outcome and every other mtcars column
# as a predictor.
m <- fit(mod, mpg ~ ., data = mtcars)

# rank all candidate models by cross-validation performance; the result has
# one row per model id and metric, with the metric mean and its rank
# (cf. workflowsets::rank_results)
rank_results_automl(m)
#> # A tibble: 414 × 5
#>    id                                                algor…¹ .metric  mean  rank
#>    <chr>                                             <chr>   <chr>   <dbl> <int>
#>  1 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     mae     1.71      1
#>  2 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     mean_r… 4.89      1
#>  3 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     mse     4.89      1
#>  4 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     r2      0.854     1
#>  5 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     residu… 4.89      1
#>  6 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     rmse    2.11      1
#>  7 GBM_lr_annealing_selection_AutoML_4_20220623_164… gbm     rmsle   0.100     1
#>  8 GBM_grid_1_AutoML_4_20220623_164917_model_53      gbm     mae     1.87      4
#>  9 GBM_grid_1_AutoML_4_20220623_164917_model_53      gbm     mean_r… 5.40      2
#> 10 GBM_grid_1_AutoML_4_20220623_164917_model_53      gbm     mse     5.40      2
#> # … with 404 more rows, and abbreviated variable name ¹​algorithm

# summarized cross-validation metrics (mean, std_err, n) per model and metric
# (cf. tune::collect_metrics)
collect_metrics(m)
#> # A tibble: 414 × 6
#>    id                                        algor…¹ .metric  mean std_err     n
#>    <chr>                                     <chr>   <chr>   <dbl>   <dbl> <int>
#>  1 GBM_lr_annealing_selection_AutoML_4_2022… gbm     mae     1.71   0.292      5
#>  2 GBM_lr_annealing_selection_AutoML_4_2022… gbm     mean_r… 4.89   1.37       5
#>  3 GBM_lr_annealing_selection_AutoML_4_2022… gbm     mse     4.89   1.37       5
#>  4 GBM_lr_annealing_selection_AutoML_4_2022… gbm     r2      0.854  0.0297     5
#>  5 GBM_lr_annealing_selection_AutoML_4_2022… gbm     residu… 4.89   1.37       5
#>  6 GBM_lr_annealing_selection_AutoML_4_2022… gbm     rmse    2.11   0.332      5
#>  7 GBM_lr_annealing_selection_AutoML_4_2022… gbm     rmsle   0.100  0.0121     5
#>  8 GBM_grid_1_AutoML_4_20220623_164917_mode… gbm     mae     1.87   0.266      5
#>  9 GBM_grid_1_AutoML_4_20220623_164917_mode… gbm     mean_r… 5.40   1.27       5
#> 10 GBM_grid_1_AutoML_4_20220623_164917_mode… gbm     mse     5.40   1.27       5
#> # … with 404 more rows, and abbreviated variable name ¹​algorithm
# summarize = FALSE returns the per-fold values: one row per model, metric,
# and cv_id, with the value in .estimate
collect_metrics(m, summarize = FALSE)
#> # A tibble: 2,070 × 5
#>    id                                              algor…¹ .metric cv_id .esti…²
#>    <chr>                                           <chr>   <chr>   <chr>   <dbl>
#>  1 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mae     cv_1…   1.76 
#>  2 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mae     cv_2…   0.945
#>  3 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mae     cv_3…   2.28 
#>  4 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mae     cv_4…   1.15 
#>  5 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mae     cv_5…   2.40 
#>  6 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mean_r… cv_1…   5.08 
#>  7 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mean_r… cv_2…   1.58 
#>  8 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mean_r… cv_3…   8.01 
#>  9 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mean_r… cv_4…   2.02 
#> 10 GBM_lr_annealing_selection_AutoML_4_20220623_1… gbm     mean_r… cv_5…   7.78 
#> # … with 2,060 more rows, and abbreviated variable names ¹​algorithm, ²​.estimate

# autoplot methods for plotting cross validation performances; the legend is
# hidden on both plots
hide_legend <- theme(legend.position = "none")

# plot ranking, restricted to the mae and rmse metrics
autoplot(m, type = "rank", metric = c("mae", "rmse")) + hide_legend

# plot metric value
autoplot(m, type = "metric") + hide_legend

# tidy method, returns the leaderboard in tidy format (here with n = 5)
m_tidy <- tidy(m, n = 5)

# add a list column holding each model's predictions for head(mtcars)
mutate(
  m_tidy,
  .predictions = purrr::map(
    .model,
    function(model) predict(model, new_data = head(mtcars))
  )
)
#> # A tibble: 5 × 5
#>   id                                          algor…¹ .metric  .model   .predi…²
#>   <chr>                                       <chr>   <list>   <list>   <list>  
#> 1 GBM_lr_annealing_selection_AutoML_4_202206… gbm     <tibble> <fit[+]> <tibble>
#> 2 GBM_grid_1_AutoML_4_20220623_164917_model_… gbm     <tibble> <fit[+]> <tibble>
#> 3 GBM_grid_1_AutoML_4_20220623_164917_model_… gbm     <tibble> <fit[+]> <tibble>
#> 4 GBM_grid_1_AutoML_4_20220623_164917_model_… gbm     <tibble> <fit[+]> <tibble>
#> 5 GBM_grid_1_AutoML_4_20220623_164917_model_… gbm     <tibble> <fit[+]> <tibble>
#> # … with abbreviated variable names ¹​algorithm, ²​.predictions
# extract single candidate model, default to leader
leader <- extract_fit_parsnip(m)
# passing a model id selects that candidate instead; here the second id in
# m_tidy, returned as the underlying engine fit
extract_fit_engine(m, m_tidy$id[[2]])
#> Model Details:
#> ==============
#> 
#> H2ORegressionModel: gbm
#> Model ID:  GBM_grid_1_AutoML_4_20220623_164917_model_53 
#> Model Summary: 
#>   number_of_trees number_of_internal_trees model_size_in_bytes min_depth
#> 1              44                       44               14674         6
#>   max_depth mean_depth min_leaves max_leaves mean_leaves
#> 1        11    7.75000         15         27    21.86364
#> 
#> 
#> H2ORegressionMetrics: gbm
#> ** Reported on training data. **
#> 
#> MSE:  0.0215296
#> RMSE:  0.1467297
#> MAE:  0.1023678
#> RMSLE:  0.007417177
#> Mean Residual Deviance :  0.0215296
#> 
#> 
#> 
#> H2ORegressionMetrics: gbm
#> ** Reported on cross-validation data. **
#> ** 5-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
#> 
#> MSE:  5.170959
#> RMSE:  2.273974
#> MAE:  1.847085
#> RMSLE:  0.1064549
#> Mean Residual Deviance :  5.170959
#> 
#> 
#> Cross-Validation Metrics Summary: 
#>                            mean       sd cv_1_valid cv_2_valid cv_3_valid
#> mae                    1.867650 0.595411   1.770470   0.901639   2.323883
#> mean_residual_deviance 5.403489 2.845899   5.722073   1.189838   8.539475
#> mse                    5.403489 2.845899   5.722073   1.189838   8.539475
#> r2                     0.839580 0.066902   0.765536   0.928443   0.782040
#> residual_deviance      5.403489 2.845899   5.722073   1.189838   8.539475
#> rmse                   2.234830 0.715039   2.392086   1.090797   2.922238
#> rmsle                  0.104290 0.027327   0.095074   0.062640   0.116891
#>                        cv_4_valid cv_5_valid
#> mae                      1.967302   2.374955
#> mean_residual_deviance   4.296808   7.269247
#> mse                      4.296808   7.269247
#> r2                       0.873206   0.848677
#> residual_deviance        4.296808   7.269247
#> rmse                     2.072874   2.696154
#> rmsle                    0.111666   0.135181

# predict on the first rows of mtcars with the extracted leader model
predict(leader, head(mtcars))
#> # A tibble: 6 × 1
#>   .pred
#>   <dbl>
#> 1  21.0
#> 2  21.0
#> 3  22.8
#> 4  21.2
#> 5  18.7
#> 6  18.2

# variable importance in the metalearner, i.e. the importance of each base
# learner within a stacked ensemble; unnesting `importance` gives one row per
# importance value
weights <- tidyr::unnest(member_weights(m), importance)

weights
#> # A tibble: 360 × 6
#>    ensemble_id                                   rank member algor…¹ type  value
#>    <chr>                                        <int> <chr>  <chr>   <chr> <dbl>
#>  1 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     rela… 0.713
#>  2 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     scal… 1    
#>  3 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     perc… 0.141
#>  4 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     rela… 0.673
#>  5 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     scal… 0.944
#>  6 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     perc… 0.133
#>  7 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     rela… 0.557
#>  8 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     scal… 0.781
#>  9 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     perc… 0.110
#> 10 StackedEnsemble_AllModels_3_AutoML_4_202206…    24 GBM_g… gbm     rela… 0.543
#> # … with 350 more rows, and abbreviated variable name ¹​algorithm

# boxplots of importance values by algorithm, one panel per importance type
ggplot(weights, aes(x = algorithm, y = value)) +
  geom_boxplot() +
  facet_wrap(vars(type))

# can join with tibbles from other functions: here each ensemble's member
# weights are matched with its cross-validation metrics and ranks
ranked_metrics <- select(rank_results_automl(m), id, .metric, mean, rank)

left_join(
  member_weights(m),
  ranked_metrics,
  by = c("ensemble_id" = "id")
)
#> # A tibble: 56 × 6
#>    ensemble_id                            rank.x import…¹ .metric    mean rank.y
#>    <chr>                                   <int> <list>   <chr>     <dbl>  <int>
#>  1 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> mae       2.21      26
#>  2 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> mean_r…   7.18      24
#>  3 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> mse       7.18      24
#>  4 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> null_d… 245.         8
#>  5 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> r2        0.273     55
#>  6 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> residu…  45.3       52
#>  7 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> rmse      2.58      26
#>  8 StackedEnsemble_AllModels_3_AutoML_4_…     24 <tibble> rmsle     0.126     26
#>  9 StackedEnsemble_AllModels_2_AutoML_4_…     25 <tibble> mae       1.91      10
#> 10 StackedEnsemble_AllModels_2_AutoML_4_…     25 <tibble> mean_r…   6.54      18
#> # … with 46 more rows, and abbreviated variable name ¹​importance

# refit with additional 30s of training time
m2 <- refit(m, max_runtime_secs = 30)
# printing the refit object shows the AutoML summary and a leaderboard preview
m2
#> parsnip model object
#> 
#> H2O AutoML Summary: 182 models
#> ==============
#> Leader Algorithm: gbm 
#> Leader ID: GBM_grid_2_AutoML_5_20220623_165004_model_106 
#> 
#> Leaderboard Preview
#>                                                           model_id     rmse
#> 1                    GBM_grid_2_AutoML_5_20220623_165004_model_106 2.061458
#> 2 GBM_lr_annealing_selection_AutoML_4_20220623_164917_select_model 2.155621
#> 3             DeepLearning_grid_4_AutoML_5_20220623_165004_model_1 2.159179
#> 4                    GBM_grid_2_AutoML_5_20220623_165004_model_156 2.217426
#> 5                    GBM_grid_2_AutoML_5_20220623_165004_model_127 2.220473
#> 6                      GBM_grid_2_AutoML_5_20220623_165004_model_1 2.240486
#>        mse      mae      rmsle mean_residual_deviance
#> 1 4.249607 1.597284 0.09762673               4.249607
#> 2 4.646704 1.683702 0.10084137               4.646704
#> 3 4.662054 1.678089 0.10452470               4.662054
#> 4 4.916977 1.758023 0.10578299               4.916977
#> 5 4.930502 1.733195 0.09974655               4.930502
#> 6 5.019779 1.743794 0.10615148               5.019779

Created on 2022-06-23 by the reprex package (v2.0.1)

qiushiyan commented 2 years ago

I will rename rank_automl to rank_results_automl until the generics PR.

The .metric issue was probably due to my earlier use of dplyr::nest_by; I have switched to plain tidyr::nest, so it is now a simple list column.

github-actions[bot] commented 1 year ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.