cmu-delphi / epipredict

Tools for building predictive models in epidemiology.
https://cmu-delphi.github.io/epipredict/

Look into using `{butcher}` to reduce fitted object size #302

Open · dshemetov opened this issue 3 months ago

dshemetov commented 3 months ago

Aside: if we are super concerned with space (not clear we are at the moment, seems a "nice to have" rather than a "mandatory for adding new features"), we may want to investigate ways to use {butcher} for existing workflows.

Originally posted by @dajmcdon in https://github.com/cmu-delphi/epipredict/issues/293#issuecomment-1936620136
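
For reference, {butcher} ships axe methods for workflows objects, and `epi_workflow` builds on `workflows::workflow`, so one low-effort option (sketched here under the assumption that those methods carry over to the epipredict wrapper) would be to butcher the fitted workflow as a whole rather than just the inner model fit, e.g. applied to the `wf` object built in the reprex below:

``` r
# Hedged sketch, assuming {butcher}'s workflow methods apply to epi_workflow
# objects (which subclass workflows::workflow); `wf` is the fitted workflow
# from the reprex below.
wf_small <- butcher::butcher(wf, verbose = TRUE)
lobstr::obj_size(wf_small)
```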

It's a bit strange that the lm fit object here is so much larger than the training dataset. I'm guessing there's per-quantile duplication going on here.

``` r
library(epipredict)
#> Loading required package: epiprocess
#> 
#> Attaching package: 'epiprocess'
#> The following object is masked from 'package:stats':
#> 
#>     filter
#> Loading required package: parsnip

# Basic fitting example
jhu <- case_death_rate_subset
r <- epi_recipe(jhu) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
  step_epi_naomit()
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14)
preds <- predict(wf, latest)

# Recursively apply `func` to the elements of a nested list down to `depth`
# levels, flattening one level per recursion so the names below print as
# compound paths (e.g. pre_mold_predictors)
apply_nested_flatten <- function(nested_list, func, depth = 1) {
  lapply(nested_list, function(item) {
    if (is.list(item) && depth > 1) {
      apply_nested_flatten(item, func, depth - 1)
    } else {
      func(item)
    }
  }) %>% `names<-`(names(nested_list)) %>% purrr::list_flatten()
}

# Get a sense for the object sizes
lobstr::obj_sizes(!!!list(jhu = jhu, r = r, wf = wf, latest = latest, preds = preds))
#> jhu   : 661.51 kB
#> r     :  16.74 kB
#> wf    :   4.55 MB
#> latest:  27.70 kB
#> preds :   2.26 kB
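# Note: wf (4.55 MB) is roughly 7x the size of the training data jhu (662 kB)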
# Inspect the wf object, note that the fit object is the largest
lobstr::obj_sizes(!!!apply_nested_flatten(wf, function(x) x, 3))
#> pre_actions_recipe     :  20.75 kB
#> pre_mold_predictors    : 928.59 kB
#> pre_mold_outcomes      : 155.10 kB
#> pre_mold_blueprint     :  14.94 kB
#> pre_mold_extras_roles  : 622.83 kB
#> pre_case_weights       :       0 B
#> fit_actions_model      :   2.58 kB
#> fit_fit_lvl            :       0 B
#> fit_fit_spec           :  19.69 kB
#> fit_fit_fit            :   2.80 MB
#> fit_fit_preproc_x_var  :       0 B
#> fit_fit_preproc_y_var  :       0 B
#> fit_fit_elapsed_elapsed:      56 B
#> fit_meta_max_time_value:     112 B
#> fit_meta_as_of         :       0 B
#> trained                :       0 B
# Go deeper
lobstr::obj_sizes(!!!apply_nested_flatten(wf, function(x) x, 4))
#> pre_actions_recipe_recipe                      :  18.54 kB
#> pre_actions_recipe_blueprint                   :   1.65 kB
#> pre_mold_predictors_lag_0_death_rate           : 154.61 kB
#> pre_mold_predictors_lag_7_death_rate           : 154.61 kB
#> pre_mold_predictors_lag_14_death_rate          : 154.61 kB
#> pre_mold_predictors_lag_0_case_rate            : 154.61 kB
#> pre_mold_predictors_lag_7_case_rate            : 154.61 kB
#> pre_mold_predictors_lag_14_case_rate           : 154.61 kB
#> pre_mold_outcomes_ahead_7_death_rate           : 154.61 kB
#> pre_mold_blueprint_intercept                   :       0 B
#> pre_mold_blueprint_allow_novel_levels          :       0 B
#> pre_mold_blueprint_composition                 :       0 B
#> pre_mold_blueprint_ptypes_predictors           :     392 B
#> pre_mold_blueprint_ptypes_outcomes             :     296 B
#> pre_mold_blueprint_fresh                       :       0 B
#> pre_mold_blueprint_strings_as_factors          :       0 B
#> pre_mold_blueprint_recipe                      :  12.58 kB
#> pre_mold_blueprint_extra_role_ptypes_time_value:     464 B
#> pre_mold_blueprint_extra_role_ptypes_geo_value :     408 B
#> pre_mold_blueprint_extra_role_ptypes_raw       :     472 B
#> pre_mold_extras_roles_time_value               : 155.08 kB
#> pre_mold_extras_roles_geo_value                : 157.97 kB
#> pre_mold_extras_roles_raw                      : 309.57 kB
#> pre_case_weights                               :       0 B
#> fit_actions_model_spec                         :   2.08 kB
#> fit_actions_model_formula                      :       0 B
#> fit_fit_lvl                                    :       0 B
#> fit_fit_spec_args_penalty                      :       0 B
#> fit_fit_spec_args_mixture                      :       0 B
#> fit_fit_spec_mode                              :       0 B
#> fit_fit_spec_user_specified_mode               :       0 B
#> fit_fit_spec_method_libs                       :     112 B
#> fit_fit_spec_method_fit                        :   1.52 kB
#> fit_fit_spec_method_pred                       :  17.37 kB
#> fit_fit_spec_engine                            :       0 B
#> fit_fit_spec_user_specified_engine             :       0 B
#> fit_fit_fit_coefficients                       :     344 B
#> fit_fit_fit_residuals                          : 155.98 kB
#> fit_fit_fit_effects                            : 309.27 kB
#> fit_fit_fit_rank                               :      56 B
#> fit_fit_fit_fitted.values                      : 154.66 kB
#> fit_fit_fit_assign                             :      80 B
#> fit_fit_fit_qr                                 :   1.08 MB
#> fit_fit_fit_df.residual                        :      56 B
#> fit_fit_fit_call                               :   7.92 kB
#> fit_fit_fit_terms                              :   4.55 kB
#> fit_fit_fit_model                              :   1.08 MB
#> fit_fit_preproc_x_var                          :       0 B
#> fit_fit_preproc_y_var                          :      56 B
#> fit_fit_elapsed_elapsed                        :      56 B
#> fit_meta_max_time_value                        :     112 B
#> fit_meta_as_of                                 :       0 B
#> trained                                        :       0 B
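
# Takeaway from the breakdown above: the lm fit's footprint is dominated by
# the qr decomposition (~1.08 MB), the stored model frame (~1.08 MB), and the
# effects/residuals/fitted values (~0.62 MB combined)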

# Use butcher to reduce the memory
small_lm <- butcher::butcher(wf$fit$fit$fit, verbose = TRUE)
#> ✔ Memory released: 1.24 MB
#> ✖ Disabled: `print()`, `summary()`, and `fitted()`
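
# Compare the per-component weights before and after: butcher drops the terms
# environment, the call, and the fitted values, but keeps the qr
# decomposition, effects, residuals, and the model frame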
butcher::weigh(wf$fit$fit$fit) %>% print(n=20)
#> # A tibble: 21 × 2
#>    object                      size
#>    <chr>                      <dbl>
#>  1 terms                   1.10    
#>  2 call                    1.09    
#>  3 qr.qr                   1.08    
#>  4 effects                 0.310   
#>  5 residuals               0.156   
#>  6 fitted.values           0.156   
#>  7 model...y               0.155   
#>  8 model.lag_0_death_rate  0.155   
#>  9 model.lag_7_death_rate  0.155   
#> 10 model.lag_14_death_rate 0.155   
#> 11 model.lag_0_case_rate   0.155   
#> 12 model.lag_7_case_rate   0.155   
#> 13 model.lag_14_case_rate  0.155   
#> 14 coefficients            0.000848
#> 15 qr.qraux                0.000112
#> 16 assign                  0.00008 
#> 17 qr.pivot                0.00008 
#> 18 rank                    0.000056
#> 19 qr.tol                  0.000056
#> 20 qr.rank                 0.000056
#> # ℹ 1 more row
# Still a lot of memory used, even after butcher's cleanup
butcher::weigh(small_lm) %>% print(n=20)
#> # A tibble: 21 × 2
#>    object                      size
#>    <chr>                      <dbl>
#>  1 qr.qr                   1.08    
#>  2 effects                 0.310   
#>  3 residuals               0.156   
#>  4 model...y               0.155   
#>  5 model.lag_0_death_rate  0.155   
#>  6 model.lag_7_death_rate  0.155   
#>  7 model.lag_14_death_rate 0.155   
#>  8 model.lag_0_case_rate   0.155   
#>  9 model.lag_7_case_rate   0.155   
#> 10 model.lag_14_case_rate  0.155   
#> 11 terms                   0.00570 
#> 12 coefficients            0.000848
#> 13 qr.qraux                0.000112
#> 14 call                    0.000112
#> 15 assign                  0.00008 
#> 16 qr.pivot                0.00008 
#> 17 rank                    0.000056
#> 18 qr.tol                  0.000056
#> 19 qr.rank                 0.000056
#> 20 df.residual             0.000056
#> # ℹ 1 more row
```

Created on 2024-03-28 with reprex v2.0.2
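
A possible follow-up, sketched under the assumption that only point predictions are needed from the stored workflow: what survives `butcher()` is dominated by the qr decomposition and the copy of the model frame. The qr component has to stay, since `predict.lm()` consults it for pivoting, but the model frame can be avoided at fit time by passing lm's `model = FALSE` through parsnip's engine arguments (assuming parsnip forwards it to `stats::lm()` and nothing downstream needs the stored frame):

``` r
# Hedged sketch, not part of the reprex above; reuses `r` and `jhu` from it.
# Assumes parsnip forwards `model = FALSE` to stats::lm() and that nothing
# downstream needs the stored model frame.
slim_spec <- parsnip::linear_reg() %>%
  parsnip::set_engine("lm", model = FALSE)
wf_slim <- epi_workflow(r, slim_spec) %>% fit(jhu)

# Same slot as wf$fit$fit$fit above; the ~1 MB model frame should be gone.
lobstr::obj_size(wf_slim$fit$fit$fit)
```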