Closed vpnagraj closed 1 year ago
@dwill023 I'm going to leave myself assigned here too ... but please get started on this one.
for context: we had been using the evaluate_models() function from trendeval to identify the "best fit" model automatically from a list of models:
https://github.com/signaturescience/fiphde/blob/main/R/glm.R#L26-L33
you don't have to dive too deep into the weeds on our API for now. the first question you should try to address is ... what is the comparable functionality in trendeval now that evaluate_models() is no longer in the package?
a few breadcrumbs to get started: the old usage lives in fiphde. devtools::install_version() can help you get the two different versions of trendeval installed (i.e., one that still exports evaluate_models() and one with the new API) so you can confirm that both return the same kinds of metrics. don't feel like you need to push code to this repository. if you can put together a reproducible example of both APIs, you could post the example code / output in this thread. that would get us well underway, and then we can operationalize how to change our API from there. but we need to start with that example of the new API.
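for the side-by-side comparison, something like the sketch below could work. note the version string is a placeholder (check the CRAN archive for the actual last pre-0.1.0 release of trendeval); devtools::install_version() itself is a real function.

```r
## sketch: pin the pre-0.1.0 trendeval (the last version that exported
## evaluate_models()) to capture the old behavior, then reinstall the
## current release to exercise the new API.
## NOTE: "0.0.2" is a placeholder version string, not verified.
devtools::install_version("trendeval", version = "0.0.2")

## ... run the old evaluate_models() example in a fresh session and
## save the metrics it returns ...

## then restore the current release and repeat with the new API:
install.packages("trendeval")
```

running each half in a fresh R session avoids the two installed versions clobbering each other in the loaded namespace.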
New API:
library(trendeval)

x <- rnorm(100, mean = 0)
y <- rpois(n = 100, lambda = exp(x + 1))
dat <- data.frame(x = x, y = y)
models <- list(
poisson_model = trending::glm_model(y ~ x, poisson),
linear_model = trending::lm_model(y ~ x)
)
res <- evaluate_resampling(models, dat)
dplyr::glimpse(res)
#Rows: 200
#Columns: 10
#$ model_name <chr> "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model…
#$ metric <chr> "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rms…
#$ result <dbl> 3.567591748, 1.221492528, 0.597965961, 1.604374413, 0.255587393, 2.483827026, 1.230081306, 0.350645465, 0.004610133,…
#$ warnings <list> <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NU…
#$ errors <list> <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NU…
#$ model <list> <Untrained trending model:, glm(formula = y ~ x, family = poisson)>, <Untrained trending model:, glm(formul…
#$ fitting_warnings <tibble[,1]> <tbl_df[48 x 1]>
#$ fitting_errors <tibble[,1]> <tbl_df[48 x 1]>
#$ predicting_warnings <tibble[,1]> <tbl_df[48 x 1]>
#$ predicting_errors <tibble[,1]> <tbl_df[48 x 1]>
best_by_rmse <-
res %>%
# dplyr::filter(purrr::map_lgl(warnings, is.null)) %>% # remove models that gave warnings
dplyr::filter(purrr::map_lgl(errors, is.null)) %>% # remove models that errored
dplyr::slice_min(result) %>%
dplyr::select(model) %>%
purrr::pluck(1,1)
best_by_rmse
# Untrained trending model:
# glm(formula = y ~ x, family = poisson)
best_by_rmse$family
# poisson
## fit the model
tmp_fit <-
best_by_rmse %>%
trending::fit(dat)
tmp_fit
#<trending_fit_tbl> 1 x 3
# result warnings errors
# <list> <list> <list>
# 1 <glm> <NULL> <NULL>
## so the returned tibble would be something like the below
## (note: this assumes dat carries a location column, as in the fiphde
## data; the toy dat above does not have one)
ret <- dplyr::tibble(model_class = best_by_rmse$family,
                     fit = tmp_fit,
                     location = unique(dat$location),
                     data = tidyr::nest(dat, fit_data = dplyr::everything()))
@dwill023 can you please post this as a true reprex() and include sessionInfo()? I can't get this code to run on my system right now. it would help to see what package versions you're using.
Below is the reprex of what I was running for the new trendeval API
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(trendeval)
x <- rnorm(100, mean = 0)
y <- rpois(n = 100, lambda = exp(x + 1))
dat <- data.frame(x = x, y = y)
models <- list(
poisson_model = trending::glm_model(y ~ x, poisson),
linear_model = trending::lm_model(y ~ x)
)
res <- evaluate_resampling(models, dat)
head(res)
#> # A tibble: 6 × 10
#> model_name metric result warnings errors model fitting_warnings$warn…¹
#> <chr> <chr> <dbl> <list> <list> <list> <list>
#> 1 poisson_model rmse 9.36 <NULL> <NULL> <glm_trn_> <NULL>
#> 2 poisson_model rmse 2.34 <NULL> <NULL> <glm_trn_> <NULL>
#> 3 poisson_model rmse 0.396 <NULL> <NULL> <glm_trn_> <NULL>
#> 4 poisson_model rmse 0.599 <NULL> <NULL> <glm_trn_> <NULL>
#> 5 poisson_model rmse 0.334 <NULL> <NULL> <glm_trn_> <NULL>
#> 6 poisson_model rmse 0.503 <NULL> <NULL> <glm_trn_> <NULL>
#> # ℹ abbreviated name: ¹fitting_warnings$warnings
#> # ℹ 3 more variables: fitting_errors <tibble[,1]>,
#> # predicting_warnings <tibble[,1]>, predicting_errors <tibble[,1]>
best_by_rmse <-
res %>%
# dplyr::filter(purrr::map_lgl(warnings, is.null)) %>% # remove models that gave warnings
dplyr::filter(purrr::map_lgl(errors, is.null)) %>% # remove models that errored
dplyr::slice_min(result) %>%
dplyr::select(model) %>%
purrr::pluck(1,1)
best_by_rmse
#> Untrained trending model:
#> lm(formula = y ~ x)
best_by_rmse$family
#> NULL
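note that $family comes back NULL here because lm models carry no family element, so using it as a model label only works for the glm case. a more defensive labeling might look like the sketch below (this fallback logic is illustrative, not part of the trending API):

```r
## best_by_rmse$family is "poisson" for the glm model but NULL for the
## lm model; fall back on the object's class as a rough label when no
## family is present (class name shown is whatever trending assigns)
model_class <- if (!is.null(best_by_rmse$family)) {
  best_by_rmse$family
} else {
  class(best_by_rmse)[1]
}
```
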
## fit the model
tmp_fit <-
best_by_rmse %>%
trending::fit(dat)
tmp_fit
#> <trending_fit_tbl> 1 x 3
#> result warnings errors
#> <list> <list> <list>
#> 1 <lm> <NULL> <NULL>
tmp_fit$result
#> [[1]]
#>
#> Call:
#> lm(formula = y ~ x, data = dat)
#>
#> Coefficients:
#> (Intercept) x
#> 5.043 4.860
sessionInfo()
#> R version 4.3.0 (2023-04-21 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 19045)
#>
#> Matrix products: default
#>
#>
#> locale:
#> [1] LC_COLLATE=English_United States.utf8
#> [2] LC_CTYPE=English_United States.utf8
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United States.utf8
#>
#> time zone: America/Denver
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] trendeval_0.1.0 dplyr_1.1.2
#>
#> loaded via a namespace (and not attached):
#> [1] yardstick_1.2.0 styler_1.9.1 tidyr_1.3.0 utf8_1.2.3
#> [5] future_1.32.0 generics_0.1.3 lattice_0.21-8 listenv_0.9.0
#> [9] lme4_1.1-33 digest_0.6.31 magrittr_2.0.3 evaluate_0.21
#> [13] grid_4.3.0 fastmap_1.1.1 R.oo_1.25.0 R.cache_0.16.0
#> [17] Matrix_1.5-4 R.utils_2.12.2 purrr_1.0.1 fansi_1.0.4
#> [21] codetools_0.2-19 abind_1.4-5 cli_3.6.1 rlang_1.1.1
#> [25] R.methodsS3_1.8.2 parallelly_1.35.0 splines_4.3.0 reprex_2.0.2
#> [29] withr_2.5.0 yaml_2.3.7 parallel_4.3.0 tools_4.3.0
#> [33] nloptr_2.0.3 coda_0.19-4 trending_0.1.0 minqa_1.2.5
#> [37] rsample_1.1.1 ciTools_0.6.1 boot_1.3-28.1 globals_0.16.2
#> [41] vctrs_0.6.2 R6_2.5.1 lifecycle_1.0.3 fs_1.6.2
#> [45] MASS_7.3-58.4 furrr_0.3.1 pkgconfig_2.0.3 pillar_1.9.0
#> [49] glue_1.6.2 Rcpp_1.0.10 arm_1.13-1 xfun_0.39
#> [53] tibble_3.2.1 tidyselect_1.2.0 rstudioapi_0.14 knitr_1.42
#> [57] htmltools_0.5.5 nlme_3.1-162 rmarkdown_2.21 compiler_4.3.0
fantastic.
looking at this and seeing one thing immediately ...
res <- evaluate_resampling(models, dat)
should be
res <-
evaluate_resampling(models, dat) %>%
summary()
the way it's written now will work, but it's sort of cheating: it finds the best RMSE of a single cross-validation fold ... not the best RMSE of the summary across all cross-validations by model type
try adding that summary() step and I think you'll see what I mean
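a sketch of the corrected selection, assuming trendeval v0.1.x: summary() aggregates the per-fold metrics so each model contributes one row per metric, and slice_min() then picks the model with the best RMSE on average rather than the single luckiest fold. NOTE: the column names (metric, result, model_name) are assumptions carried over from the un-summarized output above; check names(res) on your installed version.

```r
library(trendeval)
library(dplyr)

set.seed(123)
x <- rnorm(100, mean = 0)
y <- rpois(n = 100, lambda = exp(x + 1))
dat <- data.frame(x = x, y = y)

models <- list(
  poisson_model = trending::glm_model(y ~ x, poisson),
  linear_model  = trending::lm_model(y ~ x)
)

## summary() collapses the per-fold rows to one aggregated value
## per model / metric combination
res <-
  evaluate_resampling(models, dat) %>%
  summary()

## pick the model with the lowest *aggregated* rmse
## (column names here are assumptions, see note above)
best_by_rmse <-
  res %>%
  dplyr::filter(metric == "rmse") %>%
  dplyr::slice_min(result) %>%
  dplyr::pull(model_name)
```
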
either way, I think we can work this into the package. I'll put together a prototype of that.
kicking this one fully over to me for now, but I will keep you looped in on progress @dwill023
status: I've been trying to work the updated API into the fiphde:::glm_fit() helper function. the current wall I'm hitting is that the new trendeval requires an updated version of trending (another package) ... and that updated version of trending has changed its API too!
more to follow.
this is fixed on v1.1.0 branch (https://github.com/signaturescience/fiphde/tree/v1.1.0)
bottom line is that the trendeval function to evaluate the best fit model changed its API in v0.1.0 ... and the outputs of trending (which fits the models / runs the predict method) changed in v0.1.0. these are separate (companion) packages, and we had to tease apart how to use the new API to access the same model objects. should be all set.
this update will roll into a v1.1.0 release along with some other improvements to the package (e.g., #179)
it looks like the trendeval API has changed so that some of the functions we used to use are no longer exported.
(^ from the GH Actions CI/CD check at https://github.com/signaturescience/fiphde/actions/runs/4952653456/jobs/8859225690?pr=176)
this will likely involve refactoring some of our glm_* functions.