tidymodels / dials

Tools for creating tuning parameter values
https://dials.tidymodels.org/

prior_terminal_node_expo hyperparameter has incorrect default range for dbarts #251

Closed n8layman closed 1 year ago

n8layman commented 2 years ago

The problem

The default range of the prior_terminal_node_expo (power) hyperparameter is currently [0, 3]. However, an exponent less than one can lead to intensive memory demands due to explosive tree growth. The dbarts documentation explicitly recommends power > 1:

power: A vector of real numbers greater than one, setting the BART hyperparameter for the tree prior’s growth probability, given by base/(1 + depth)^power.

Source: p. 28, under the xbart section of dbarts.pdf
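To make the formula concrete, here is a minimal base R sketch of base/(1 + depth)^power. The helper name growth_prob and the example values (base = 0.95, powers of 2 and 0.5) are my own illustrative assumptions, not taken from dials or dbarts code.

``` r
# Growth probability of a node at a given depth under the tree prior:
# base / (1 + depth)^power (formula quoted from the dbarts documentation).
# `growth_prob` is a hypothetical helper written for this illustration.
growth_prob <- function(depth, base = 0.95, power = 2) {
  base / (1 + depth)^power
}

depths <- 0:5
p_high <- growth_prob(depths, power = 2)   # exponent above one
p_low  <- growth_prob(depths, power = 0.5) # exponent below one

# Side-by-side comparison: the probability decays much more slowly
# with depth when power < 1, so trees are allowed to grow far deeper.
data.frame(depth = depths, power_2 = p_high, power_0.5 = p_low)
```

With power below one the growth probability stays large at every depth, which is consistent with the explosive tree growth described above.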

Reproducible example

library(tidyverse)

dials::prior_terminal_node_expo()
#> Terminal Node Prior Exponent (quantitative)
#> Range: [0, 3]

n_folds <- 10
grid <- dials::grid_latin_hypercube(size = n_folds,
                                    dials::prior_outcome_range(), # k
                                    dials::prior_terminal_node_expo(), # power
                                    dials::prior_terminal_node_coef(), # base
                                    dials::trees(range = c(25, 300))) |>
  dplyr::rename(k = prior_outcome_range, power = prior_terminal_node_expo, base = prior_terminal_node_coef, n.trees = trees)
grid

#> # A tibble: 10 × 4
#>        k  power   base n.trees
#>    <dbl>  <dbl>  <dbl>   <int>
#>  1 2.55  0.0803 0.964       61
#>  2 4.47  1.26   0.132      233
#>  3 1.40  1.78   0.863      182
#>  4 1.84  2.66   0.525      255
#>  5 0.194 1.02   0.602      199
#>  6 2.22  1.89   0.766      120
#>  7 4.98  0.387  0.0652     294
#>  8 3.93  2.88   0.212       87
#>  9 0.761 0.724  0.438       39
#> 10 3.38  2.23   0.329      152

# Let's see what the tree depth growth curves look like.
max_depth <- 20
# One row per grid candidate, one column per depth (1 to max_depth).
depth <- matrix(1:max_depth, nrow = nrow(grid), ncol = max_depth, byrow = TRUE)
# Growth probability base / (1 + depth)^power for each candidate at each depth.
depth <- apply(depth, 2, function(v) grid$base / ((1 + v)^grid$power))
grid |>
  dplyr::bind_cols(depth, .name_repair = ~ c(names(grid), paste0("X", 1:max_depth))) |>
  dplyr::mutate(r = 1:dplyr::n()) |>
  tidyr::pivot_longer(starts_with("X"), names_prefix = "X", values_to = "growth probability",
                      names_to = "depth", names_transform = list(depth = as.integer)) |>
  ggplot2::ggplot(aes(x = depth, y = `growth probability`, col = as.factor(r), group = r)) +
  geom_line()
#> Warning: Removed 5 row(s) containing missing values (geom_path).

[Figure: tree_growth_test. Growth probability versus depth, one line per grid row.]

Created on 2022-09-06 by the reprex package (v2.0.1)
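As a possible workaround while the default is [0, 3], the range can be set explicitly when creating the parameter. This is only a sketch: the lower bound of 1 is my reading of the dbarts recommendation, and I use dials::grid_random() here for brevity instead of the Latin hypercube grid from the reprex.

``` r
# Override the default range so sampled exponents are never below one.
# The bound c(1, 3) is an assumption based on the dbarts recommendation
# quoted above, not a value taken from the dials source.
power <- dials::prior_terminal_node_expo(range = c(1, 3))

# Draw a small random grid using the restricted parameter.
grid <- dials::grid_random(
  power,
  dials::prior_terminal_node_coef(),
  size = 10
)

# Smallest and largest sampled exponent; both should lie in [1, 3].
range(grid$prior_terminal_node_expo)
```

Note that the lower bound is inclusive by default, so this allows power = 1 even though the dbarts documentation asks for values greater than one.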

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.5.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Los_Angeles
#>  date     2022-09-06
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package       * version date (UTC) lib source
#>  P assertthat      0.2.1   2019-03-21 [?] CRAN (R 4.2.0)
#>  P backports       1.4.1   2021-12-13 [?] CRAN (R 4.2.0)
#>    broom           1.0.1   2022-08-29 [1] CRAN (R 4.2.0)
#>  P cellranger      1.1.0   2016-07-27 [?] CRAN (R 4.2.0)
#>  P cli             3.3.0   2022-04-25 [?] CRAN (R 4.2.0)
#>  P colorspace      2.0-3   2022-02-21 [?] CRAN (R 4.2.0)
#>  P crayon          1.5.1   2022-03-26 [?] CRAN (R 4.2.0)
#>  P curl            4.3.2   2021-06-23 [?] CRAN (R 4.2.0)
#>  P DBI             1.1.3   2022-06-18 [?] CRAN (R 4.2.0)
#>    dbplyr          2.2.1   2022-06-27 [1] CRAN (R 4.2.0)
#>    dials           1.0.0   2022-06-14 [1] CRAN (R 4.2.0)
#>  P DiceDesign      1.9     2021-02-13 [?] CRAN (R 4.2.0)
#>  P digest          0.6.29  2021-12-01 [?] CRAN (R 4.2.0)
#>    dplyr         * 1.0.9   2022-04-28 [1] CRAN (R 4.2.0)
#>  P ellipsis        0.3.2   2021-04-29 [?] CRAN (R 4.2.0)
#>  P evaluate        0.15    2022-02-18 [?] CRAN (R 4.2.0)
#>  P fansi           1.0.3   2022-03-24 [?] CRAN (R 4.2.0)
#>  P farver          2.1.0   2021-02-28 [?] CRAN (R 4.2.0)
#>  P fastmap         1.1.0   2021-01-25 [?] CRAN (R 4.2.0)
#>  P forcats       * 0.5.1   2021-01-27 [?] CRAN (R 4.2.0)
#>  P fs              1.5.2   2021-12-08 [?] CRAN (R 4.2.0)
#>  P gargle          1.2.0   2021-07-02 [?] CRAN (R 4.2.0)
#>    generics        0.1.3   2022-07-05 [1] CRAN (R 4.2.0)
#>    ggplot2       * 3.3.6   2022-05-03 [1] CRAN (R 4.2.0)
#>  P glue            1.6.2   2022-02-24 [?] CRAN (R 4.2.0)
#>  P googledrive     2.0.0   2021-07-08 [?] CRAN (R 4.2.0)
#>  P googlesheets4   1.0.0   2021-07-21 [?] CRAN (R 4.2.0)
#>  P gtable          0.3.0   2019-03-25 [?] CRAN (R 4.2.0)
#>    hardhat         1.2.0   2022-06-30 [1] CRAN (R 4.2.0)
#>    haven           2.5.1   2022-08-22 [1] CRAN (R 4.2.0)
#>  P highr           0.9     2021-04-16 [?] CRAN (R 4.2.0)
#>  P hms             1.1.1   2021-09-26 [?] CRAN (R 4.2.0)
#>  P htmltools       0.5.2   2021-08-25 [?] CRAN (R 4.2.0)
#>    httr            1.4.3   2022-05-04 [1] CRAN (R 4.2.0)
#>  P jsonlite        1.8.0   2022-02-22 [?] CRAN (R 4.2.0)
#>  P knitr           1.38    2022-03-25 [?] RSPM (R 4.2.0)
#>  P labeling        0.4.2   2020-10-20 [?] CRAN (R 4.2.0)
#>  P lifecycle       1.0.1   2021-09-24 [?] CRAN (R 4.2.0)
#>  P lubridate       1.8.0   2021-10-07 [?] CRAN (R 4.2.0)
#>  P magrittr        2.0.3   2022-03-30 [?] CRAN (R 4.2.0)
#>  P mime            0.12    2021-09-28 [?] CRAN (R 4.2.0)
#>  P modelr          0.1.8   2020-05-19 [?] CRAN (R 4.2.0)
#>  P munsell         0.5.0   2018-06-12 [?] CRAN (R 4.2.0)
#>    pillar          1.8.0   2022-07-18 [1] CRAN (R 4.2.0)
#>  P pkgconfig       2.0.3   2019-09-22 [?] CRAN (R 4.2.0)
#>  P purrr         * 0.3.4   2020-04-17 [?] CRAN (R 4.2.0)
#>  P R.cache         0.16.0  2022-07-21 [?] CRAN (R 4.2.0)
#>    R.methodsS3     1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo            1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils         2.12.0  2022-06-28 [1] CRAN (R 4.2.0)
#>  P R6              2.5.1   2021-08-19 [?] CRAN (R 4.2.0)
#>    readr         * 2.1.2   2022-01-30 [1] CRAN (R 4.2.0)
#>  P readxl          1.4.0   2022-03-28 [?] CRAN (R 4.2.0)
#>  P reprex          2.0.1   2021-08-05 [?] CRAN (R 4.2.0)
#>    rlang           1.0.4   2022-07-12 [1] CRAN (R 4.2.0)
#>  P rmarkdown       2.11    2021-09-14 [?] RSPM (R 4.2.0)
#>  P rstudioapi      0.13    2020-11-12 [?] CRAN (R 4.2.0)
#>  P rvest           1.0.2   2021-10-16 [?] CRAN (R 4.2.0)
#>  P scales          1.2.0   2022-04-13 [?] CRAN (R 4.2.0)
#>  P sessioninfo     1.2.2   2021-12-06 [?] CRAN (R 4.2.0)
#>    stringi         1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  P stringr       * 1.4.0   2019-02-10 [?] CRAN (R 4.2.0)
#>  P styler          1.7.0   2022-03-13 [?] CRAN (R 4.2.0)
#>    tibble        * 3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  P tidyr         * 1.2.0   2022-02-01 [?] CRAN (R 4.2.0)
#>  P tidyselect      1.1.2   2022-02-21 [?] CRAN (R 4.2.0)
#>    tidyverse     * 1.3.2   2022-07-18 [1] CRAN (R 4.2.0)
#>    tzdb            0.3.0   2022-03-28 [1] CRAN (R 4.2.0)
#>  P utf8            1.2.2   2021-07-24 [?] CRAN (R 4.2.0)
#>  P vctrs           0.4.1   2022-04-13 [?] CRAN (R 4.2.0)
#>  P withr           2.5.0   2022-03-03 [?] CRAN (R 4.2.0)
#>    xfun            0.32    2022-08-10 [1] CRAN (R 4.2.0)
#>  P xml2            1.3.3   2021-11-30 [?] CRAN (R 4.2.0)
#>  P yaml            2.3.5   2022-02-21 [?] CRAN (R 4.2.0)
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

hfrick commented 2 years ago

Thanks for the issue @n8layman ! I've made a PR to address that.

github-actions[bot] commented 1 year ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.