tidymodels / dials

Tools for creating tuning parameter values
https://dials.tidymodels.org/

transition `tibble()` -> `new_tibble(list())` in `parameters_constr()` #277

Closed: simonpcouch closed this 1 year ago

simonpcouch commented 1 year ago
library(tidymodels)

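parameters_constr() is called twice per resample fit, and twice for every element of the grid when tuning hyperparameters. The benchmark below therefore compares one fit_resamples() run over 100 bootstraps with the 200 constructor calls it triggers.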

c0 <- character(0)
set.seed(2023)

bm <- 
  bench::mark(
    total = fit_resamples(linear_reg(), mpg ~ ., bootstraps(mtcars, 100)),
    parameters_constr = replicate(200, parameters_constr(c0, c0, c0, c0, c0, list())),
    check = FALSE
  )
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.

bm
#> # A tibble: 2 × 6
#>   expression             min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr>        <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 total                4.13s    4.13s     0.242    52.4MB     10.2
#> 2 parameters_constr  89.73ms  90.56ms    10.9     369.9KB     14.5

As a percentage of total time, the constructor takes:

100 * as.numeric(bm$median[2]) / as.numeric(bm$median[1])
#> [1] 2.193956

parameters_constr() is called from only one place in the package, and it is used there in such a way that we know we won’t need recycling: https://github.com/tidymodels/dials/blob/ec3cd5154bfae1b677c753861ea4330b11330bb9/R/parameters.R#L45-L52

It is true that this function is exported and could be used anywhere, but I think the _constr suffix implies that it is for advanced usage only.
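To illustrate the recycling point, here is a minimal sketch with toy columns (the column names `a` and `b` are made up for illustration and are not from dials). It assumes only the tibble package: tibble() recycles length-one inputs, while new_tibble() takes a list of equal-length columns as given, and for such inputs the two should produce equal results.

```r
library(tibble)

# tibble() recycles the length-1 column `b` to match `a`.
recycled <- tibble(a = 1:3, b = "x")
nrow(recycled)

# With equal-length columns no recycling is needed, so the low-level
# constructor can be used instead; `nrow` is supplied explicitly.
via_tibble <- tibble(a = 1:3, b = letters[1:3])
via_new <- new_tibble(list(a = 1:3, b = letters[1:3]), nrow = 3)

# Same kind of comparison that bench::mark(check = TRUE) performs.
same <- isTRUE(all.equal(via_tibble, via_new))
same
```

new_tibble() skips the recycling and checking that tibble() does, so it is only a safe swap when the caller already guarantees consistent column lengths, as is argued here for the single call site.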

The rewritten version of the constructor just switches out tibble(...) for new_tibble(list(...)):

parameters_constr2 <-
  function(name, id, source, component, component_id, object) {
    dials:::chr_check(name)
    dials:::chr_check(id)
    dials:::chr_check(source)
    dials:::chr_check(component)
    dials:::chr_check(component_id)
    dials:::unique_check(id)
    if (is.null(object)) {
      rlang::abort("Element `object` should not be NULL.")
    }
    if (!is.list(object)) {
      rlang::abort("`object` should be a list.")
    }
    is_good_boi <- purrr::map_lgl(object, dials:::param_or_na)
    if (any(!is_good_boi)) {
      rlang::abort(
        paste0(
          "`object` values in the following positions should be NA or a ",
          "`param` object: ",
          paste0(which(!is_good_boi), collapse = ", ")
        )
      )
    }
    res <-
      tibble::new_tibble(
        list(
          name = name,
          id = id,
          source = source,
          component = component,
          component_id = component_id,
          object = object
        ),
        nrow = length(name)
      )
    class(res) <- c("parameters", class(res))
    res
  }

With benchmarks:

bm2 <- 
  bench::mark(
    old = parameters_constr(c0, c0, c0, c0, c0, list()),
    new = parameters_constr2(c0, c0, c0, c0, c0, list()),
    check = TRUE
  )

Note that with check = TRUE in the above, mark() checks that the two outputs are equal. :) The old constructor is roughly 11 times slower than the new one:

as.numeric(bm2$median[1]) / as.numeric(bm2$median[2])
#> [1] 11.24562

Created on 2023-03-14 with reprex v2.0.2
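As a rough back-of-the-envelope sketch, the two figures reported above can be combined to estimate the overall saving. This assumes that the speedup measured on empty inputs carries over to the calls made during fit_resamples(), which the benchmarks do not confirm; the variable names are illustrative only.

```r
# Figures copied from the reprex output above.
share_pct <- 2.193956  # constructor's share of total fit_resamples() time, in percent
speedup   <- 11.24562  # old median / new median

# If constructor time shrinks by a factor of `speedup`, the share of
# total time saved is the old share minus the new, smaller share.
saved_pct <- share_pct * (1 - 1 / speedup)
round(saved_pct, 1)
```

Under those assumptions, the change would cut roughly 2% off the total time of this example.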

hfrick commented 1 year ago

Thank you! 🏎️
