cmu-delphi / epipredict

Tools for building predictive models in epidemiology.
https://cmu-delphi.github.io/epipredict/

`add_shifted_columns` can produce grouped output that is passed through subsequent steps #413

Status: open · opened by brookslogan 2 days ago

brookslogan commented 2 days ago

See the logic in add_shifted_columns():

  processed <- new_data %>%
    full_join(shifted, by = ok) %>%
    # grouping by the non-time keys is added here...
    group_by(across(all_of(kill_time_value(ok)))) %>%
    arrange(time_value)
  # ...but it is only removed when new_data is an epi_df
  if (inherits(new_data, "epi_df")) {
    processed <- processed %>%
      ungroup() %>%
      as_epi_df(
        as_of = attributes(new_data)$metadata$as_of,
        other_keys = attributes(new_data)$metadata$other_keys
      )
  }
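A minimal, self-contained sketch of the same pattern, using toy data rather than epipredict internals. The objects new_data, shifted, and ok are made-up stand-ins, and setdiff(ok, "time_value") is an assumed substitute for the internal kill_time_value(ok). Because new_data here is a plain tibble, the epi_df branch never runs and the result keeps its grouping.

```r
library(dplyr)

# Toy stand-ins (assumptions): new_data is a plain tibble like the one reaching
# bake(), shifted holds lagged values, ok holds the key columns.
new_data <- tibble(
  geo_value = rep(c("ca", "ny"), each = 3),
  time_value = rep(as.Date("2021-12-01") + 0:2, times = 2),
  case_rate = c(1, 2, 3, 10, 20, 30)
)
shifted <- new_data %>%
  mutate(time_value = time_value + 1) %>%
  rename(lag_1_case_rate = case_rate)
ok <- c("geo_value", "time_value")

# Same pipeline as the excerpt; setdiff() stands in for kill_time_value(ok)
processed <- new_data %>%
  full_join(shifted, by = ok) %>%
  group_by(across(all_of(setdiff(ok, "time_value")))) %>%
  arrange(time_value)

# new_data is not an epi_df, so nothing ungroups the result
class(processed)
group_vars(processed)
```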

And we do appear to get plain tibbles rather than epi_dfs when baking, so the ungroup() branch above is skipped:

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(epiprocess)
#> Registered S3 method overwritten by 'tsibble':
#>   method               from 
#>   as_tibble.grouped_df dplyr
#> 
#> Attaching package: 'epiprocess'
#> The following object is masked from 'package:stats':
#> 
#>     filter
library(epipredict)
#> Loading required package: parsnip
#> Registered S3 method overwritten by 'epipredict':
#>   method            from   
#>   print.step_naomit recipes
trace(prep, quote({print(class(x));print(class(list(...)$training))}))
#> Tracing function "prep" in package "epipredict"
#> [1] "prep"
trace(bake, quote({print(class(object));print(class(list(...)$new_data))}))
#> Tracing function "bake" in package "epipredict"
#> [1] "bake"
jhu <- case_death_rate_subset %>%
  dplyr::filter(time_value >= as.Date("2021-12-01"))
out <- arx_forecaster(
  jhu, "death_rate",
  c("case_rate", "death_rate")
)
#> Tracing recipes::prep(blueprint$recipe, training = training, fresh = blueprint$fresh,  .... on entry 
#> [1] "epi_recipe" "recipe"    
#> [1] "epi_df"     "tbl_df"     "tbl"        "data.frame"
#> Tracing recipes::bake(object = rec, new_data = new_data) on entry 
#> [1] "epi_recipe" "recipe"    
#> [1] "tbl_df"     "tbl"        "data.frame"
#> Tracing bake(step, new_data = new_data) on entry 
#> [1] "step_epi_lag" "step"        
#> [1] "tbl_df"     "tbl"        "data.frame"
#> Tracing bake(step, new_data = new_data) on entry 
#> [1] "step_epi_lag" "step"        
#> [1] "grouped_df" "tbl_df"     "tbl"        "data.frame"
#> Tracing bake(step, new_data = new_data) on entry 
#> [1] "step_epi_ahead" "step"          
#> [1] "grouped_df" "tbl_df"     "tbl"        "data.frame"
#> Tracing bake(step, new_data = new_data) on entry 
#> [1] "step_naomit" "step"       
#> [1] "grouped_df" "tbl_df"     "tbl"        "data.frame"

Created on 2024-10-16 with reprex v2.1.1

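The steps after the lags and aheads here appear to give the same results whether their input is grouped or ungrouped, so there may be no immediate problem in arx_forecaster(). But that won't always be the case: any later step whose bake method calls grouping-aware dplyr verbs such as mutate() or summarise() would silently compute per-group results.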
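To illustrate why leaked grouping can matter downstream, here is a sketch with a hypothetical later step that centers a column; it is not an existing epipredict step, and the data are made up. On a grouped_df, mean() inside mutate() is evaluated per group, so identical step code gives different answers depending on whether the grouping leaked through.

```r
library(dplyr)

# Toy data (assumption): two geos, grouped the way add_shifted_columns can leave them
df <- tibble(
  geo_value = c("ca", "ca", "ny", "ny"),
  x = c(1, 3, 10, 30)
)
grouped <- group_by(df, geo_value)

# Hypothetical downstream step: center x by its mean.
# Grouped input: mean() is computed separately for each geo_value.
centered_grouped <- grouped %>%
  mutate(x_centered = x - mean(x)) %>%
  pull(x_centered)

# Ungrouped input: mean() is computed over all rows.
centered_ungrouped <- df %>%
  mutate(x_centered = x - mean(x)) %>%
  pull(x_centered)

centered_grouped
centered_ungrouped
```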

dajmcdon commented 2 days ago

@dsweber2 We should always ungroup() after baking.
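One way to follow that suggestion, sketched as a hypothetical helper that mirrors the excerpt in the issue rather than the actual epipredict code: move ungroup() out of the epi_df branch so it runs for every input class, and keep only the as_epi_df() conversion conditional. The helper name join_shifted and the toy data are assumptions.

```r
library(dplyr)

# Hypothetical helper (not epipredict's implementation): same join and sort as
# the excerpt, but ungroup() runs unconditionally instead of only for epi_df
# input. An epi_df input would still be re-wrapped with as_epi_df() afterwards.
join_shifted <- function(new_data, shifted, keys) {
  new_data %>%
    full_join(shifted, by = keys) %>%
    group_by(across(all_of(setdiff(keys, "time_value")))) %>%
    arrange(time_value) %>%
    ungroup()
}

# Usage with toy data (assumed): a plain tibble, as seen in the traced bake() calls
new_data <- tibble(
  geo_value = rep(c("ca", "ny"), each = 2),
  time_value = rep(as.Date("2021-12-01") + 0:1, times = 2),
  case_rate = c(1, 2, 10, 20)
)
shifted <- new_data %>%
  mutate(time_value = time_value + 1) %>%
  rename(lag_1_case_rate = case_rate)

out <- join_shifted(new_data, shifted, keys = c("geo_value", "time_value"))
is_grouped_df(out)
```

Since arrange() ignores grouping unless .by_group = TRUE, the group_by() call does not change the sort order here, so dropping it entirely might be an equally valid fix.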