cmu-delphi / epipredict

Tools for building predictive models in epidemiology.
https://cmu-delphi.github.io/epipredict/

Raise warning/error on bad ahead/lagsets (on training data) #332

Open brookslogan opened 1 month ago

brookslogan commented 1 month ago

When working with weekly data, it's easy to mess up the specification of lagsets and aheadsets, but all we get is a confusing/imprecise error message about 0 non-NA cases, and it's quite challenging to debug errors through steps and layers, especially with S3 involved. (epiprocess#342 is also relevant here, although since the error happens right off the bat, the current approach of just letting errors pass through does allow more debugging tools such as recover(); it's just not very helpful in this case.)
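
Before the full reprex: if I understand the shifting correctly, the core mismatch is just date arithmetic. With weekly-spaced time_values, an ahead (or lag) that isn't a multiple of 7 shifts onto dates that were never observed, so the shifted column ends up entirely NA. (Illustrative snippet only, not part of the reprex.)

weekly <- seq(as.Date("2019-01-05"), by = "1 week", length.out = 5)
(weekly + 1L) %in% weekly # all FALSE: a 1-day shift never hits an observed week
(weekly + 7L) %in% weekly # TRUE except at the trailing boundary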

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(tidyr)
library(purrr)
library(ggplot2)
library(epidatr)
#> ! epidatr cache is being used (set env var EPIDATR_USE_CACHE=FALSE if not
#>   intended).
#> ℹ The cache directory is ~/.cache/R/epidatr.
#> ℹ The cache will be cleared after 14 days and will be pruned if it exceeds 4096
#>   MB.
#> ℹ The log of cache transactions is stored at ~/.cache/R/epidatr/logfile.txt.
library(epiprocess)
#> 
#> Attaching package: 'epiprocess'
#> The following object is masked from 'package:stats':
#> 
#>     filter
library(epipredict)
#> Loading required package: parsnip
#> 
#> Attaching package: 'epipredict'
#> The following object is masked from 'package:ggplot2':
#> 
#>     layer

flusurv_analysis_issue <- as.Date("2019-08-01") %>%
  MMWRweek::MMWRweek() %>%
  {.$MMWRyear * 100L + .$MMWRweek}

flusurv_issue_data <-
  pub_flusurv(
    locations = "network_all",
    issues = epirange(123401, flusurv_analysis_issue)
  )
#> Warning: Loading from the cache at /home/fullname/.cache/R/epidatr; see
#> ~/.cache/R/epidatr/logfile.txt for more details.
#> This warning is displayed once every 8 hours.

flusurv_archive <- flusurv_issue_data %>%
  select(geo_value = location,
         time_value = epiweek,
         version = release_date,
         starts_with("rate_")) %>%
  as_epi_archive(compactify = TRUE)

archive <- flusurv_archive

forecast_dates <- seq(min(archive$DT$version) + 120L, archive$versions_end,
                      by = "6 weeks")

horizons <- 1 + c(0, 7, 14, 21, 28) # relative to forecast_date

example_forecaster <- function(snapshot_edf, forecast_date) {
  # shared_reporting_latency <- as.integer(forecast_date - max(snapshot_edf$time_value))
  horizons %>%
    map(function(horizon) {
      snapshot_edf %>%
        arx_forecaster(
          outcome = "rate_overall",
          predictors = "rate_overall",
          args_list = arx_args_list(
            # (this is incomplete; latency often varies significantly by covariate and can't be ignored, so we also need lag adjustment.)
            ahead = horizon, # <-- oops, forgot latency adjustment
            quantile_levels = c(0.1, 0.5, 0.9),
            forecast_date = forecast_date,
            target_date = forecast_date + horizon
          )) %>%
        .$predictions
    }) %>%
    bind_rows()
  ## list()
}

pseudoprospective_forecasts <-
  archive %>%
  epix_slide(
    ref_time_values = forecast_dates,
    before = 365000L, # 1000-year time window --> don't filter out any `time_value`s
    ~ example_forecaster(.x, .ref_time_value),
    names_sep = NULL
  ) %>%
  select(-time_value)
#> Warning in max.default(structure(numeric(0), class = "Date"), na.rm = FALSE):
#> no non-missing arguments to max; returning -Inf
#> Error in `map()` at rlang/R/dots.R:91:3:
#> ℹ In index: 1.
#> Caused by error in `lm.fit()`:
#> ! 0 (non-NA) cases
#> Backtrace:
#>      ▆
#>   1. ├─... %>% select(-time_value)
#>   2. ├─dplyr::select(., -time_value)
#>   3. ├─epiprocess::epix_slide(...) at dplyr/R/select.R:54:3
#>   4. │ └─x$slide(...)
#>   5. │   ├─... %>% ungroup()
#>   6. │   └─self$group_by()$slide(...) at dplyr/R/group-by.R:153:3
#>   7. │     └─base::lapply(...)
#>   8. │       └─epiprocess (local) FUN(X[[i]], ...)
#>   9. │         ├─dplyr::group_modify(...)
#>  10. │         ├─epiprocess:::group_modify.epi_df(...) at dplyr/R/group-map.R:156:3
#>  11. │         │ └─dplyr::dplyr_reconstruct(NextMethod(), .data)
#>  12. │         │   └─dplyr:::dplyr_new_data_frame(data) at dplyr/R/generics.R:196:3
#>  13. │         │     ├─row.names %||% .row_names_info(x, type = 0L) at dplyr/R/utils.R:18:3
#>  14. │         │     └─base::.row_names_info(x, type = 0L) at dplyr/R/utils.R:18:3
#>  15. │         ├─base::NextMethod()
#>  16. │         └─dplyr:::group_modify.data.frame(...)
#>  17. │           └─epiprocess (local) .f(.data, group_keys(.data), ...) at dplyr/R/group-map.R:166:3
#>  18. │             └─f(.data_group, .group_key, ref_time_value, ...)
#>  19. │               └─global example_forecaster(.x, .ref_time_value)
#>  20. │                 └─... %>% bind_rows()
#>  21. ├─dplyr::ungroup(.)
#>  22. ├─dplyr::bind_rows(.)
#>  23. │ └─rlang::list2(...) at dplyr/R/bind-rows.R:31:3
#>  24. ├─purrr::map(...) at rlang/R/dots.R:91:3
#>  25. │ └─purrr:::map_("list", .x, .f, ..., .progress = .progress) at purrr/R/map.R:129:3
#>  26. │   ├─purrr:::with_indexed_errors(...) at purrr/R/map.R:174:3
#>  27. │   │ └─base::withCallingHandlers(...) at purrr/R/map.R:201:3
#>  28. │   ├─purrr:::call_with_cleanup(...) at purrr/R/map.R:174:3
#>  29. │   └─.f(.x[[i]], ...)
#>  30. │     └─... %>% .$predictions
#>  31. ├─epipredict::arx_forecaster(...)
#>  32. │ ├─generics::fit(wf, epi_data)
#>  33. │ ├─epipredict:::fit.epi_workflow(wf, epi_data)
#>  34. │ ├─base::NextMethod()
#>  35. │ └─workflows:::fit.workflow(wf, epi_data)
#>  36. │   └─workflows::.fit_model(workflow, control)
#>  37. │     ├─generics::fit(action_model, workflow = workflow, control = control)
#>  38. │     └─workflows:::fit.action_model(...)
#>  39. │       └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#>  40. │         ├─generics::fit_xy(...)
#>  41. │         └─parsnip::fit_xy.model_spec(...)
#>  42. │           └─parsnip:::xy_form(...)
#>  43. │             └─parsnip:::form_form(...)
#>  44. │               └─parsnip:::eval_mod(...)
#>  45. │                 └─rlang::eval_tidy(e, env = envir, ...)
#>  46. ├─stats::lm(formula = ..y ~ ., data = data) at rlang/R/eval-tidy.R:121:3
#>  47. │ └─stats::lm.fit(...)
#>  48. │   └─base::stop("0 (non-NA) cases")
#>  49. └─base::.handleSimpleError(...)
#>  50.   └─purrr (local) h(simpleError(msg, call))
#>  51.     └─cli::cli_abort(...) at purrr/R/map.R:215:9
#>  52.       └─rlang::abort(...) at cli/R/rlang.R:45:3

Created on 2024-05-14 with reprex v2.0.2
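
For what it's worth, the same failure mode should be reproducible without the archive/slide machinery. The sketch below is untested and the toy data set is made up, but with weekly-spaced time_values and ahead = 1 I'd expect the same "0 (non-NA) cases" error, since the ahead-shifted outcome never lines up with an observed time_value.

library(dplyr)
library(epiprocess)
library(epipredict)

# Untested sketch: a toy weekly epi_df with a single signal.
toy_edf <- tibble::tibble(
  geo_value = "ca",
  time_value = seq(as.Date("2020-01-04"), by = "1 week", length.out = 52),
  rate = rnorm(52)
) %>%
  as_epi_df()

# ahead = 1 day on weekly data: the shifted outcome should be all NA,
# so fitting should fail the same way as in the reprex above.
toy_edf %>%
  arx_forecaster(
    outcome = "rate",
    predictors = "rate",
    args_list = arx_args_list(ahead = 1)
  )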

The logic needed is probably a bit more complicated than one might think, since the aheadset and lagset don't necessarily include 0; a pre-shift check like all(is.na( <something> )) or !all( <x> %in% <y> ) isn't enough. For step_epi_lag, I suspect we need two checks:

- A: that none of the incoming predictors (potentially including some that aren't being shifted) is already all NAs, so that B doesn't produce a confusing error message for a problem that predates the shifting; and
- B: that the output lagged signals, the other (unshifted) predictors, (maybe other columns with roles?,) and, when training, the outcomes have at least some overlapping non-NA rows.

Messaging about it helpfully is probably even harder. Maybe we could print a slice of the output df merged with the original df (to include the original, unshifted versions of the signals), chosen to include at least one non-NA row for each shifted output if any exist (I think A plus our shifting method should guarantee this?), so users can see where things don't line up. A rough sketch of check B follows below.

Whether this should be a warning or an error probably depends on whether this is recoverable via other steps.
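
Here is roughly what check B might look like. The helper name check_enough_overlap and the idea of running it on the baked training data are made up for illustration; the real check would presumably live in the prep()/bake() methods of step_epi_lag()/step_epi_ahead() or in a dedicated check_* step.

# Hypothetical helper, not existing epipredict API: given the training data
# after the lag/ahead steps have been applied, check that the shifted
# predictors and shifted outcome(s) share at least one fully non-NA row,
# and point at the all-NA columns if they don't.
check_enough_overlap <- function(baked_df, predictor_cols, outcome_cols) {
  relevant <- baked_df[, c(predictor_cols, outcome_cols), drop = FALSE]
  if (!any(stats::complete.cases(relevant))) {
    all_na_cols <- names(relevant)[colSums(!is.na(relevant)) == 0L]
    cli::cli_abort(c(
      "No training rows have all lagged predictors and ahead-shifted outcomes non-NA.",
      "i" = "Columns that are entirely NA after shifting: {.val {all_na_cols}}.",
      "i" = "Check that lags/aheads are compatible with the time_value spacing (e.g. multiples of 7 days for weekly data) and with reporting latency."
    ))
  }
  invisible(TRUE)
}

Whether it calls cli_abort() or cli_warn() could then be a choice tied to the warning-vs-error question above.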

dajmcdon commented 2 weeks ago

Can you make a simple example without the slide?