signaturescience / fiphde

Forecasting Influenza in Support of Public Health Decision Making
https://signaturescience.github.io/fiphde/
GNU General Public License v3.0

trendeval API change #177

Closed vpnagraj closed 1 year ago

vpnagraj commented 1 year ago

It looks like the trendeval API has changed so that some of the functions we had been using are no longer exported:

Error: Error: 'evaluate_models' is not an exported object from 'namespace:trendeval'

^ from the GitHub Actions CI/CD check at https://github.com/signaturescience/fiphde/actions/runs/4952653456/jobs/8859225690?pr=176

This will likely involve refactoring some of our glm_* functions.

vpnagraj commented 1 year ago

@dwill023 I'm going to leave myself assigned here too ... but please get started on this one.

For context: we had been using the evaluate_models() function from trendeval to automatically identify the "best fit" model from a list of models:

https://github.com/signaturescience/fiphde/blob/main/R/glm.R#L26-L33

You don't have to dive too deep into the weeds of our API for now.

The first question you should try to address is: what is the comparable functionality in trendeval now that evaluate_models() is no longer in the package?

A few breadcrumbs to get started:
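
For example, a quick generic way to see what the installed trendeval exports now, and to confirm the old entry point is really gone:

# list the exported functions in the installed trendeval and check for evaluate_models()
library(trendeval)
packageVersion("trendeval")
sort(getNamespaceExports("trendeval"))
"evaluate_models" %in% getNamespaceExports("trendeval")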

vpnagraj commented 1 year ago

And don't feel like you need to push code to this repository. If you can put together a reproducible example of both APIs, you could post the example code / output in this thread. That would get us well underway, and then we can operationalize how to change our API from there. But we need to start with that example of the new API.

dwill023 commented 1 year ago

New API:

library(dplyr)      # for the %>% pipe
library(trendeval)  # provides evaluate_resampling()

x <- rnorm(100, mean = 0)
y <- rpois(n = 100, lambda = exp(x + 1))
dat <- data.frame(x = x, y = y)

models <- list(
  poisson_model = trending::glm_model(y ~ x, poisson),
  linear_model = trending::lm_model(y ~ x)
)

res <- evaluate_resampling(models, dat)

dplyr::glimpse(res)
#Rows: 200
#Columns: 10
#$ model_name          <chr> "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model", "poisson_model…
#$ metric              <chr> "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rmse", "rms…
#$ result              <dbl> 3.567591748, 1.221492528, 0.597965961, 1.604374413, 0.255587393, 2.483827026, 1.230081306, 0.350645465, 0.004610133,…
#$ warnings            <list> <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NU…
#$ errors              <list> <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NULL>, <NU…
#$ model               <list> <Untrained trending model:,     glm(formula = y ~ x, family = poisson)>, <Untrained trending model:,     glm(formul…
#$ fitting_warnings    <tibble[,1]> <tbl_df[48 x 1]>
#$ fitting_errors      <tibble[,1]> <tbl_df[48 x 1]>
#$ predicting_warnings <tibble[,1]> <tbl_df[48 x 1]>
#$ predicting_errors   <tibble[,1]> <tbl_df[48 x 1]>

best_by_rmse <-
  res %>%
  # dplyr::filter(purrr::map_lgl(warnings, is.null)) %>%  # remove models that gave warnings
  dplyr::filter(purrr::map_lgl(errors, is.null))  %>%   # remove models that errored
  dplyr::slice_min(result) %>%
  dplyr::select(model) %>%
  purrr::pluck(1,1)

best_by_rmse
# Untrained trending model:
#    glm(formula = y ~ x, family = poisson)

best_by_rmse$family
# poisson

## fit the model
tmp_fit <-
  best_by_rmse %>%
  trending::fit(dat)

tmp_fit
#<trending_fit_tbl> 1 x 3
#  result warnings errors
#  <list> <list>   <list>
# 1 <glm>  <NULL>   <NULL>

# so the tibble `ret` will be (note: `location` exists in the real fiphde data, not in this toy `dat`)
ret <- dplyr::tibble(model_class = best_by_rmse$family,
                     fit = tmp_fit,
                     location = unique(dat$location),
                     data = tidyr::nest(dat, fit_data = dplyr::everything()))

vpnagraj commented 1 year ago

@dwill023 can you please post this as a true reprex() and include sessionInfo()? I can't get this code to run on my system right now. It would help to see what package versions you're using.
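
For example, something like this with the example script copied to the clipboard (reprex can append the session info for you):

# render the clipboard contents as a reprex and append sessionInfo() output
reprex::reprex(session_info = TRUE)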

dwill023 commented 1 year ago

Below is a reprex of what I was running against the new trendeval API:

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(trendeval)

x <- rnorm(100, mean = 0)
y <- rpois(n = 100, lambda = exp(x + 1))
dat <- data.frame(x = x, y = y)

models <- list(
  poisson_model = trending::glm_model(y ~ x, poisson),
  linear_model = trending::lm_model(y ~ x)
)

res <- evaluate_resampling(models, dat)
head(res)
#> # A tibble: 6 × 10
#>   model_name    metric result warnings errors model      fitting_warnings$warn…¹
#>   <chr>         <chr>   <dbl> <list>   <list> <list>     <list>                 
#> 1 poisson_model rmse    9.36  <NULL>   <NULL> <glm_trn_> <NULL>                 
#> 2 poisson_model rmse    2.34  <NULL>   <NULL> <glm_trn_> <NULL>                 
#> 3 poisson_model rmse    0.396 <NULL>   <NULL> <glm_trn_> <NULL>                 
#> 4 poisson_model rmse    0.599 <NULL>   <NULL> <glm_trn_> <NULL>                 
#> 5 poisson_model rmse    0.334 <NULL>   <NULL> <glm_trn_> <NULL>                 
#> 6 poisson_model rmse    0.503 <NULL>   <NULL> <glm_trn_> <NULL>                 
#> # ℹ abbreviated name: ¹​fitting_warnings$warnings
#> # ℹ 3 more variables: fitting_errors <tibble[,1]>,
#> #   predicting_warnings <tibble[,1]>, predicting_errors <tibble[,1]>

best_by_rmse <-
  res %>%
  # dplyr::filter(purrr::map_lgl(warnings, is.null)) %>%  # remove models that gave warnings
  dplyr::filter(purrr::map_lgl(errors, is.null))  %>%   # remove models that errored
  dplyr::slice_min(result) %>%
  dplyr::select(model) %>%
  purrr::pluck(1,1)

best_by_rmse
#> Untrained trending model:
#>     lm(formula = y ~ x)
best_by_rmse$family
#> NULL

## fit the model
tmp_fit <-
  best_by_rmse %>%
  trending::fit(dat)

tmp_fit
#> <trending_fit_tbl> 1 x 3
#>   result warnings errors
#>   <list> <list>   <list>
#> 1 <lm>   <NULL>   <NULL>
tmp_fit$result
#> [[1]]
#> 
#> Call:
#> lm(formula = y ~ x, data = dat)
#> 
#> Coefficients:
#> (Intercept)            x  
#>       5.043        4.860

sessionInfo()
#> R version 4.3.0 (2023-04-21 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 19045)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=English_United States.utf8 
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/Denver
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] trendeval_0.1.0 dplyr_1.1.2    
#> 
#> loaded via a namespace (and not attached):
#>  [1] yardstick_1.2.0   styler_1.9.1      tidyr_1.3.0       utf8_1.2.3       
#>  [5] future_1.32.0     generics_0.1.3    lattice_0.21-8    listenv_0.9.0    
#>  [9] lme4_1.1-33       digest_0.6.31     magrittr_2.0.3    evaluate_0.21    
#> [13] grid_4.3.0        fastmap_1.1.1     R.oo_1.25.0       R.cache_0.16.0   
#> [17] Matrix_1.5-4      R.utils_2.12.2    purrr_1.0.1       fansi_1.0.4      
#> [21] codetools_0.2-19  abind_1.4-5       cli_3.6.1         rlang_1.1.1      
#> [25] R.methodsS3_1.8.2 parallelly_1.35.0 splines_4.3.0     reprex_2.0.2     
#> [29] withr_2.5.0       yaml_2.3.7        parallel_4.3.0    tools_4.3.0      
#> [33] nloptr_2.0.3      coda_0.19-4       trending_0.1.0    minqa_1.2.5      
#> [37] rsample_1.1.1     ciTools_0.6.1     boot_1.3-28.1     globals_0.16.2   
#> [41] vctrs_0.6.2       R6_2.5.1          lifecycle_1.0.3   fs_1.6.2         
#> [45] MASS_7.3-58.4     furrr_0.3.1       pkgconfig_2.0.3   pillar_1.9.0     
#> [49] glue_1.6.2        Rcpp_1.0.10       arm_1.13-1        xfun_0.39        
#> [53] tibble_3.2.1      tidyselect_1.2.0  rstudioapi_0.14   knitr_1.42       
#> [57] htmltools_0.5.5   nlme_3.1-162      rmarkdown_2.21    compiler_4.3.0

vpnagraj commented 1 year ago

Fantastic.

Looking at this, I see one thing immediately ...

res <- evaluate_resampling(models, dat)

should be

res <-
  evaluate_resampling(models, dat) %>%
  summary()

The way it's written now will work, but it's sort of cheating: it finds the best RMSE from a single cross-validation fold rather than the best RMSE summarized across all cross-validation folds by model type.

Try adding that summary() step and I think you'll see what I mean.
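
To make the distinction concrete, here is a rough hand-rolled version of what that summary() step does, using only the columns already shown in the reprex (just a sketch, not the code we would ship):

# average the per-fold RMSE within each model and keep the model with the lowest mean
best_name <-
  res %>%
  dplyr::filter(purrr::map_lgl(errors, is.null)) %>%
  dplyr::filter(metric == "rmse") %>%
  dplyr::group_by(model_name) %>%
  dplyr::summarise(mean_rmse = mean(result)) %>%
  dplyr::slice_min(mean_rmse) %>%
  dplyr::pull(model_name)

# the untrained model with the lowest mean cross-validated RMSE
models[[best_name]]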

Either way, I think we can work this into the package. I'll put together a prototype of that.

vpnagraj commented 1 year ago

Kicking this one fully over to me for now, but I will keep you looped in on progress @dwill023.

Status: I've been trying to work the updated API into the fiphde:::glm_fit() helper function. The current wall I'm hitting is that the new trendeval requires an updated version of trending (another package) ... and that updated version of trending has changed its API too!

More to follow.

vpnagraj commented 1 year ago

This is fixed on the v1.1.0 branch (https://github.com/signaturescience/fiphde/tree/v1.1.0).

Bottom line: the trendeval function to evaluate the best-fit model changed its API in v0.1.0, and the outputs of trending (which fits the models / runs the predict method) changed in its v0.1.0 release as well. These are separate (companion) packages, and we had to tease apart how to use the new API to access the same model objects. Should be all set.
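
For reference, in the reprex above the underlying fitted model object can be pulled out of the result list-column of the trending_fit_tbl, e.g.:

# extract the fitted lm/glm object from the new trending fit output
fitted_model <- tmp_fit$result[[1]]
class(fitted_model)
predict(fitted_model, newdata = dat)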

This update will roll into a v1.1.0 release along with some other improvements to the package (e.g., #179).