prior_terminal_node_expo hyperparameter has incorrect default range for dbarts #251

Closed. n8layman closed this issue 1 year ago.

n8layman commented 2 years ago

The problem

The prior_terminal_node_expo (power) hyperparameter currently has a default range of [0, 3]. However, an exponent less than one can lead to heavy memory demands because the trees grow explosively. The dbarts documentation explicitly recommends power > 1:

power: A vector of real numbers greater than one, setting the BART hyperparameter for the tree prior’s growth probability, given by base/(1 + depth)^power.

Source: dbarts.pdf, p. 28, under the xbart section.
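To make the "explosive tree growth" point concrete, here is a minimal R sketch (not dbarts code) of the expected number of nodes at each depth implied by the prior alone. It assumes the root sits at depth 0 and that a node at depth d splits into two children with probability base/(1 + d)^power, so E[N_{d+1}] = 2 * base/(1 + d)^power * E[N_d]. It ignores the data and the minimum-node-size limits that cap real trees. The function name `expected_nodes_by_depth()` is hypothetical; the base/power pairs are rows 1 and 4 of the grid in the reprex.

```r
# Hypothetical helper: expected node count at depths 0..max_depth under the
# tree prior only, where a node at depth d splits with probability
# base / (1 + d)^power and each split creates two children.
expected_nodes_by_depth <- function(base, power, max_depth) {
  n <- numeric(max_depth + 1)
  n[1] <- 1                                 # n[1] is depth 0: one root node
  for (i in seq_len(max_depth)) {
    parent_depth <- i - 1                   # n[i] holds depth i - 1
    p_split <- base / (1 + parent_depth)^power
    n[i + 1] <- 2 * p_split * n[i]          # expected children at depth i
  }
  n
}

# Row 1 of the grid (power < 1) versus row 4 (power > 1)
low  <- expected_nodes_by_depth(base = 0.964, power = 0.0803, max_depth = 20)
high <- expected_nodes_by_depth(base = 0.525, power = 2.66,   max_depth = 20)

# Low-power prior: thousands of expected nodes by depth 20; high-power: about two
round(c(low = sum(low), high = sum(high)), 1)
```

The expected count keeps growing from one level to the next as long as 2 * base/(1 + d)^power stays above 1. With power well below 1, the split probability decays so slowly that this holds for many levels, which is where the memory blow-up comes from.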

Reproducible example


dials::prior_terminal_node_expo()
#> Terminal Node Prior Exponent (quantitative)
#> Range: [0, 3]

n_folds <- 10
grid <- dials::grid_latin_hypercube(size = n_folds,
                                    dials::prior_outcome_range(), # k
                                    dials::prior_terminal_node_expo(), # power
                                    dials::prior_terminal_node_coef(), # base
                                    dials::trees(range = c(25, 300))) |>
  dplyr::rename(k = prior_outcome_range, power = prior_terminal_node_expo, base = prior_terminal_node_coef, n.trees = trees)

#> # A tibble: 10 × 4
#>        k  power   base n.trees
#>    <dbl>  <dbl>  <dbl>   <int>
#>  1 2.55  0.0803 0.964       61
#>  2 4.47  1.26   0.132      233
#>  3 1.40  1.78   0.863      182
#>  4 1.84  2.66   0.525      255
#>  5 0.194 1.02   0.602      199
#>  6 2.22  1.89   0.766      120
#>  7 4.98  0.387  0.0652     294
#>  8 3.93  2.88   0.212       87
#>  9 0.761 0.724  0.438       39
#> 10 3.38  2.23   0.329      152

# Let's see what the tree-depth growth curves look like.
# Growth probability at depth d is base / (1 + d)^power (dbarts docs).
max_depth <- 20
depth <- matrix(1:max_depth, nrow = nrow(grid), ncol = max_depth, byrow = TRUE)
depth <- apply(depth, 2, function(v) grid$base / ((1 + v)^grid$power))
grid |>
  dplyr::bind_cols(depth, .name_repair = ~ c(names(grid), paste0("X", 1:max_depth))) |>
  dplyr::mutate(r = 1:dplyr::n()) |>   # one curve per grid row
  tidyr::pivot_longer(tidyselect::starts_with("X"), names_prefix = "X",
                      values_to = "growth probability", names_to = "depth",
                      names_transform = list(depth = as.integer)) |>
  ggplot2::ggplot(ggplot2::aes(x = depth, y = `growth probability`,
                               col = as.factor(r), group = r)) +
  ggplot2::geom_line()
#> Warning: Removed 5 row(s) containing missing values (geom_path).


Created on 2022-09-06 by the reprex package (v2.0.1)
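In the grid above, rows 1, 7 and 9 drew power values below 1. Until the default range changes, one possible workaround is to pass an explicit range to the parameter object. This is a minimal sketch that assumes a lower bound of 1 to match the dbarts recommendation. The upper bound of 3 is just the current default, not a tuned value. The object name `power_param` and the seed are illustrative.

```r
# Hypothetical workaround: restrict the exponent (dbarts' `power`) to [1, 3]
# instead of relying on the default range of [0, 3].
power_param <- dials::prior_terminal_node_expo(range = c(1, 3))

set.seed(1)  # arbitrary seed so the Latin hypercube draw is reproducible
grid <- dials::grid_latin_hypercube(
  dials::prior_outcome_range(),       # k
  power_param,                        # power, now bounded below by 1
  dials::prior_terminal_node_coef(),  # base
  dials::trees(range = c(25, 300)),   # n.trees
  size = 10
) |>
  dplyr::rename(k = prior_outcome_range, power = prior_terminal_node_expo,
                base = prior_terminal_node_coef, n.trees = trees)

range(grid$power)  # every sampled exponent should now lie in [1, 3]
```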

hfrick commented 2 years ago

Thanks for the issue, @n8layman! I've made a PR to address that.

