cmu-delphi / epipredict

Tools for building predictive models in epidemiology.
https://cmu-delphi.github.io/epipredict/
Other
10 stars 10 forks source link

`layer_population_scaling` with `by = NULL` doesn't work with any `other_keys` #410

Closed brookslogan closed 1 week ago

brookslogan commented 1 month ago

See the reprex below. `step_population_scaling` seems to work, but `layer_population_scaling` has issues.

Cause is either

The documentation may not match the intention of the implementation: I think it says that `by = NULL` selects all overlapping columns, rather than referring to epikeys in particular.

suppressPackageStartupMessages({
  library(dplyr)
  library(epiprocess)
  library(epipredict)
  library(testthat)
})

# XXX deriving from test-population_scaling.R

# Baseline data: case rates for "ca" and "ny" after 2021-11-01, keyed by
# geo_value and time_value only.
jhu <- case_death_rate_subset %>%
  dplyr::filter(
    time_value > "2021-11-01",
    geo_value %in% c("ca", "ny")
  ) %>%
  dplyr::select(geo_value, time_value, case_rate)

# Same rows, but pretend there is a single geo (1) with many age groups: the
# original geo labels move into `age_group`, which is registered as an
# additional key of the epi_df.
jhu2 <- jhu %>%
  as_tibble() %>%
  mutate(age_group = geo_value, geo_value = 1) %>%
  as_epi_df(
    as_of = attr(jhu, "metadata")$as_of,
    other_keys = "age_group"
  )

# Copy used for the variant that spells out `by` explicitly.
jhu3 <- jhu2

# Scaling table for the baseline data: one `values` entry per geo.
reverse_pop_data <- data.frame(
  geo_value = c("ca", "ny"),
  values = 1 / c(20000, 30000)
)

# Scaling table re-keyed the same way as `jhu2` (labels in `age_group`,
# geo_value fixed at 1).
reverse_pop_data2 <- reverse_pop_data %>%
  mutate(age_group = geo_value, geo_value = 1)

reverse_pop_data3 <- reverse_pop_data2

# Recipe on the baseline data; `by = NULL` leaves the join columns for the
# population scaling step unspecified.
r <- epi_recipe(jhu) %>%
  step_population_scaling(
    case_rate,
    df = reverse_pop_data,
    df_pop_col = "values",
    by = NULL,
    suffix = "_scaled"
  ) %>%
  # lagged predictors built from the scaled case rate
  step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>%
  # 7-day-ahead outcome built from the scaled case rate
  step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>%
  recipes::step_naomit(recipes::all_predictors()) %>%
  recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)

# Same recipe on the age-group data, still with `by = NULL`.
r2 <- epi_recipe(jhu2) %>%
  step_population_scaling(
    case_rate,
    df = reverse_pop_data2,
    df_pop_col = "values",
    by = NULL,
    suffix = "_scaled"
  ) %>%
  # lagged predictors built from the scaled case rate
  step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>%
  # 7-day-ahead outcome built from the scaled case rate
  step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>%
  recipes::step_naomit(recipes::all_predictors()) %>%
  recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)

# Same recipe on the age-group data, but with the join columns given
# explicitly.
r3 <- epi_recipe(jhu3) %>%
  step_population_scaling(
    case_rate,
    df = reverse_pop_data3,
    df_pop_col = "values",
    by = c("geo_value", "age_group"),
    suffix = "_scaled"
  ) %>%
  # lagged predictors built from the scaled case rate
  step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>%
  # 7-day-ahead outcome built from the scaled case rate
  step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>%
  recipes::step_naomit(recipes::all_predictors()) %>%
  recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)

# Prepare each recipe on its own training data ...
prep <- prep(r, jhu)
prep2 <- prep(r2, jhu2)
prep3 <- prep(r3, jhu3)

# ... and bake it on that same data.
b <- bake(prep, jhu)
b2 <- bake(prep2, jhu2)
b3 <- bake(prep3, jhu3)

# After mapping `age_group` back onto `geo_value`, both age-group variants
# should reproduce the baseline baked data.
expect_equal(
  b,
  b2 %>%
    mutate(geo_value = age_group, age_group = NULL) %>%
    # geo_type metadata is wrong; re-infer:
    as_tibble() %>%
    as_epi_df(as_of = attr(b, "metadata")$as_of)
)

expect_equal(
  b,
  b3 %>%
    mutate(geo_value = age_group, age_group = NULL) %>%
    # geo_type metadata is wrong; re-infer:
    as_tibble() %>%
    as_epi_df(as_of = attr(b, "metadata")$as_of)
)

# Post-processing for the baseline workflow: predict, threshold, drop NAs,
# then population-scale `.pred` with `by = NULL`.
f <- frosting() %>%
  layer_predict() %>%
  layer_threshold(.pred) %>%
  layer_naomit(.pred) %>%
  layer_population_scaling(
    .pred,
    df = reverse_pop_data,
    by = NULL,
    df_pop_col = "values"
  )

# Same layers for the age-group data, still with `by = NULL`.
f2 <- frosting() %>%
  layer_predict() %>%
  layer_threshold(.pred) %>%
  layer_naomit(.pred) %>%
  layer_population_scaling(
    .pred,
    df = reverse_pop_data2,
    by = NULL,
    df_pop_col = "values"
  )

# Same layers for the age-group data, with the join columns given explicitly.
f3 <- frosting() %>%
  layer_predict() %>%
  layer_threshold(.pred) %>%
  layer_naomit(.pred) %>%
  layer_population_scaling(
    .pred,
    df = reverse_pop_data3,
    by = c("geo_value", "age_group"),
    df_pop_col = "values"
  )

# Baseline: recipe `r` with frosting `f`.
wf <- epi_workflow(r, parsnip::linear_reg()) %>%
  fit(jhu) %>%
  add_frosting(f)

# Age-group data, `by = NULL` in both recipe and frosting.
wf2 <- epi_workflow(r2, parsnip::linear_reg()) %>%
  fit(jhu2) %>%
  add_frosting(f2)

# Age-group data, explicit `by` in both recipe and frosting.
wf3 <- epi_workflow(r3, parsnip::linear_reg()) %>%
  fit(jhu3) %>%
  add_frosting(f3)

# Mixed: recipe with `by = NULL`, frosting with explicit `by`.
wf23 <- epi_workflow(r2, parsnip::linear_reg()) %>%
  fit(jhu2) %>%
  add_frosting(f3)

# Mixed: recipe with explicit `by`, frosting with `by = NULL`.
wf32 <- epi_workflow(r3, parsnip::linear_reg()) %>%
  fit(jhu3) %>%
  add_frosting(f2)

# Test data for the baseline recipe. `jhu` was built above by exactly this
# filter/select pipeline on case_death_rate_subset, so it is reused here.
latest <- get_test_data(recipe = r, x = jhu)

# Test data for the age-group recipes: the same subset, re-keyed so that the
# geo labels live in `age_group` and geo_value is the constant 1.
latest2 <- get_test_data(
  recipe = r2,
  x = case_death_rate_subset %>%
    dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
    mutate(age_group = geo_value, geo_value = 1) %>%
    dplyr::select(geo_value, age_group, time_value, case_rate) %>%
    as_tibble() %>%
    as_epi_df(
      as_of = attr(case_death_rate_subset, "metadata")$as_of,
      other_keys = "age_group"
    )
)

latest3 <- get_test_data(
  recipe = r3,
  x = case_death_rate_subset %>%
    dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
    mutate(age_group = geo_value, geo_value = 1) %>%
    dplyr::select(geo_value, age_group, time_value, case_rate) %>%
    as_tibble() %>%
    as_epi_df(
      as_of = attr(case_death_rate_subset, "metadata")$as_of,
      other_keys = "age_group"
    )
)

# Once `age_group` is mapped back onto `geo_value`, the age-group test data
# should match the baseline test data.
expect_equal(
  latest %>%
    # is missing `other_keys`; regenerate to standardize
    as_tibble() %>%
    as_epi_df(as_of = attr(latest, "metadata")$as_of),
  latest2 %>%
    mutate(geo_value = age_group, age_group = NULL) %>%
    # geo_type metadata is wrong; re-infer:
    as_tibble() %>%
    as_epi_df(as_of = attr(latest, "metadata")$as_of)
)

expect_equal(
  latest %>%
    # is missing `other_keys`; regenerate to standardize
    as_tibble() %>%
    as_epi_df(as_of = attr(latest, "metadata")$as_of),
  latest3 %>%
    mutate(geo_value = age_group, age_group = NULL) %>%
    # geo_type metadata is wrong; re-infer:
    as_tibble() %>%
    as_epi_df(as_of = attr(latest, "metadata")$as_of)
)

# Baseline prediction (recipe `r`, frosting `f`); used as the reference result.
p <- predict(wf, latest)

# Recipe `r2` + frosting `f2` (both `by = NULL`, data has `age_group` as an
# extra key): this call errors, as shown in the captured output below.
p2 <- predict(wf2, latest2)
#> Error in `vctrs::vec_locate_matches()`:
#> ! Each value of `needles` must have a match in `haystack`.
#> ✖ Location 5 of `needles` does not have a match.

# Recipe `r3` + frosting `f3` (both with explicit `by`): no error is shown.
p3 <- predict(wf3, latest3)

# The explicit-`by` result should match the baseline once `age_group` is
# mapped back onto `geo_value`.
expect_equal(p, p3 %>% mutate(geo_value = age_group, age_group = NULL) %>%
                  # geo_type metadata is wrong; re-infer:
                  as_tibble() %>%
                  as_epi_df(as_of = attr(p, "metadata")$as_of))

# Recipe `r2` (`by = NULL`) + frosting `f3` (explicit `by`): errors with the
# same message as `wf2`.
p23 <- predict(wf23, latest2)
#> Error in `vctrs::vec_locate_matches()`:
#> ! Each value of `needles` must have a match in `haystack`.
#> ✖ Location 5 of `needles` does not have a match.

# Recipe `r3` (explicit `by`) + frosting `f2` (`by = NULL`): no error is shown.
p32 <- predict(wf32, latest3)

# The mixed `r3`/`f2` result should also match the baseline after re-keying.
expect_equal(p, p32 %>% mutate(geo_value = age_group, age_group = NULL) %>%
                  # geo_type metadata is wrong; re-infer:
                  as_tibble() %>%
                  as_epi_df(as_of = attr(p, "metadata")$as_of))

Created on 2024-10-09 with reprex v2.1.1