tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org
Other
101 stars 15 forks source link

Dynamically calculate weights #240

Open jamesgrecian opened 1 year ago

jamesgrecian commented 1 year ago

When fitting a model via cross-validation, it is likely that the number of observations in each fold differs. This is an issue if the user wants to weight particular observations based on their frequency in the data set. For example, in a binomial regression you might want to ensure that an excess of absences is weighted down relative to the presence points.

Is there an approach to calculate these weights on the fly (i.e. during a workflow) rather than specifying them prior to splitting the data into cross-validation folds?

Here's an example of how I would calculate the weights prior to splitting the data into folds. In this case it's wrong, as there are equal numbers of presence and absence locations in the dummy dataset.

I guess an alternative would be to create a function that applied the transformation to the rsample or spatialsample object?

Thanks,

James

library(sf)
#> Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(tidymodels)
library(spatialsample)

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))

# Promote the points to an sf object and recode the response to a 0/1 factor,
# as is typical in species distribution modelling.
lsl <- lsl |>
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717") |>
  mutate(lslpts = factor(as.numeric(lslpts) - 1))

# Add case weights: presences get weight 1, absences get
# n_presence / n_absence. The ratio is computed ONCE on the full data set,
# i.e. before the data are split into folds, so it does not adapt to the
# class balance of each individual analysis set.
lsl <- lsl |>
  mutate(
    case_wts = ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0)),
    case_wts = importance_weights(case_wts)
  )

lsl_folds <- lsl |>
  spatial_block_cv(method = "random", v = 10)

## Fit some workflow
glm_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

# Note: despite its name, `glm_wflow` holds the resampling results returned
# by fit_resamples(), not the workflow itself.
glm_wflow <- workflow() |>
  add_formula(lslpts ~ slope + cplan + cprof + elev + log10_carea) |>
  add_model(glm_model) |>
  add_case_weights(case_wts) |>
  fit_resamples(lsl_folds)

glm_wflow |> collect_metrics()
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.733    10  0.0148 Preprocessor1_Model1
#> 2 roc_auc  binary     0.798    10  0.0224 Preprocessor1_Model1

Created on 2023-05-10 with reprex v2.0.2

jamesgrecian commented 1 year ago

I've thought about this a bit more, and can show an example that calculates the weights on the fly for a simple model. However, I can't see an approach that allows you to do this using the native `fit_resamples()` function, which is an issue if the user is trying to tune the model parameters via `tune_grid()`.

# packages
library(sf)
#> Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(tidymodels)
library(hardhat)
library(spatialsample)

# data prep
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))

# sf conversion plus 0/1 recoding of the response, as is typical in species
# distribution modelling
lsl <- lsl |>
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717") |>
  mutate(lslpts = factor(as.numeric(lslpts) - 1))

# use spatialsample to generate 10 folds for CV
lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

# an example simple workflow
glm_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

# Fit the weighted GLM to the analysis set of a single split. The case
# weights are derived from that analysis set only: presences get 1, absences
# get n_presence / n_absence.
fit_weighted_fold <- function(fold_split) {
  fold_train <- analysis(fold_split) |>
    mutate(
      case_wts = ifelse(
        lslpts == 1,
        1,
        length(lslpts[lslpts == 1]) / length(lslpts[lslpts == 0])
      ),
      case_wts = importance_weights(case_wts)
    )

  workflow() |>
    add_formula(lslpts ~ slope + cplan + cprof + elev + log10_carea) |>
    add_model(glm_model) |>
    add_case_weights(case_wts) |>
    fit(data = fold_train)
}

# create weights on the fly: one fitted workflow per fold
fit_list <- lapply(lsl_folds$splits, fit_weighted_fold)

# repackage list to an object *similar* to a resamples object
resamples <- tibble(
  splits = lsl_folds$splits,
  id = lsl_folds$id,
  model = fit_list
)

resamples
#> # A tibble: 10 × 3
#>    splits           id     model     
#>    <list>           <chr>  <list>    
#>  1 <split [326/24]> Fold01 <workflow>
#>  2 <split [320/30]> Fold02 <workflow>
#>  3 <split [300/50]> Fold03 <workflow>
#>  4 <split [322/28]> Fold04 <workflow>
#>  5 <split [289/61]> Fold05 <workflow>
#>  6 <split [329/21]> Fold06 <workflow>
#>  7 <split [307/43]> Fold07 <workflow>
#>  8 <split [321/29]> Fold08 <workflow>
#>  9 <split [314/36]> Fold09 <workflow>
#> 10 <split [321/29]> Fold10 <workflow>

Created on 2023-05-16 with reprex v2.0.2

mikemahoney218 commented 1 year ago

Howdy @jamesgrecian ! I had to poll the tidymodels crew on this one, and the idea below comes directly from @topepo .

The basic idea here is that you'd use recipes::step_mutate() to create the case weights column. This will calculate the case weights independently for each fold:

set.seed(1107)
library(sf)
library(tidymodels)
library(spatialsample)

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))

# Convert to sf, recode the response to a 0/1 factor (as is typical in species
# distribution modelling) and attach a placeholder case-weights column.
lsl <- lsl |>
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717") |>
  mutate(
    lslpts = factor(as.numeric(lslpts) - 1),
    # Dummy case weights column, to get past initial verification:
    cwts = hardhat::importance_weights(NA)
  )

# Shared model formula, used by both the recipe and the unweighted workflow.
model_formula <- lslpts ~ slope + cplan + cprof + elev + log10_carea

# ***** Case weights are set up as a recipe step *****
# The step overwrites `cwts` with presence = 1 and
# absence = n_presence / n_absence, computed from the data the recipe is
# given.
lsl_recipe <- recipes::recipe(
  model_formula,
  data = sf::st_drop_geometry(lsl)
) |>
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    # The "case_weights" role has to be set explicitly:
    role = "case_weights"
  )

lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

glm_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

# Our original (unweighted) workflow:
unweighted_wflow <- workflow() |>
  add_formula(model_formula) |>
  add_model(glm_model)
glm_wflow <- fit_resamples(unweighted_wflow, resamples = lsl_folds)

# Weighted version: no add_formula(), because the formula lives in the recipe
weighted_wflow <- workflow() |>
  add_recipe(lsl_recipe) |>
  add_model(glm_model) |>
  add_case_weights(cwts)
glm_wflow_wts <- fit_resamples(weighted_wflow, resamples = lsl_folds)

# We get _different_ results, and I _think_ it's the correct ones based on our dynamic weights:
collect_metrics(glm_wflow)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.739    10  0.0199 Preprocessor1_Model1
#> 2 roc_auc  binary     0.808    10  0.0228 Preprocessor1_Model1
collect_metrics(glm_wflow_wts)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.746    10  0.0214 Preprocessor1_Model1
#> 2 roc_auc  binary     0.808    10  0.0227 Preprocessor1_Model1

Created on 2023-05-22 with reprex v2.0.2

I am pretty sure this works; for instance, we can spot-check the first fold and confirm that we get the same AUC metric:

# Spot check: refit the first fold by hand with stats::glm(), recomputing the
# case weights from that fold's analysis set, and score it on the matching
# assessment set.
first_split <- glm_wflow_wts$splits[[1]]

first_train <- analysis(first_split) |>
  mutate(cwts = ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0)))
first_test <- assessment(first_split)

untidy_glm <- glm(
  lslpts ~ slope + cplan + cprof + elev + log10_carea,
  data = first_train,
  family = "binomial",
  weights = cwts
)

yardstick::roc_auc_vec(
  truth = first_test$lslpts,
  estimate = predict(untidy_glm, newdata = first_test, type = "response"),
  # the event of interest is the second factor level ("1")
  event_level = "second"
)
#> [1] 0.7574371
glm_wflow_wts$.metrics[[1]]
#> # A tibble: 2 × 4
#>   .metric  .estimator .estimate .config             
#>   <chr>    <chr>          <dbl> <chr>               
#> 1 accuracy binary         0.714 Preprocessor1_Model1
#> 2 roc_auc  binary         0.757 Preprocessor1_Model1

Created on 2023-05-22 with reprex v2.0.2

I've got a handful of things to note here, though:

When using this flexible step, use extra care to avoid data leakage in your preprocessing. Consider, for example, the transformation x = w > mean(w). When applied to new data or testing data, this transformation would use the mean of w from the new data, not the mean of w from the training data.

Point being, I'd be very careful using step_mutate() for other purposes; for calculating case weights while doing CV, however, I think it's a safe approach.

Does that all make sense?

mikemahoney218 commented 1 year ago

One other thought -- in these situations, where it's not obvious which tidymodels package is "responsible" for something, you might get faster answers over on the Community site:

https://community.rstudio.com/

That might help get some eyes on these issues a bit faster.

(But of course, it's no issue to email these posts to me if no one picks them up in a few weeks :smile:)

jamesgrecian commented 1 year ago

Thanks @mikemahoney218 and @topepo - reassured that the issues I'm facing are non-trivial!

I've worked through this today and your solution makes sense. However, when applying this to my own data I'm getting some very different AUC estimates. This is compared to previous iterations with either no weights, or weights estimated on the rough ratio of presence to absences. I'm going to try and work through the cause of this, and test whether it's a model specification problem (I'm using GAMs rather than GLMs) or an issue with the weights.

jamesgrecian commented 1 year ago

Hi @mikemahoney218 and @topepo. I wonder whether the weights are being dealt with correctly in `gen_additive_mod()`?

If you run a glm on each fold separately you get the same AUC as when you run a glm using fit_resamples. However, if you repeat the same process using mgcv::gam with no smooth terms (so the model is also fitting linear terms only), you get different AUC values between fit_resamples and fitting to each fold separately:

# packages
library(sf)
#> Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(tidymodels)
library(spatialsample)

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))

# sf conversion, 0/1 recoding of the response (as is typical in species
# distribution modelling) and a placeholder case-weights column
lsl <- lsl |>
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717") |>
  mutate(
    lslpts = factor(as.numeric(lslpts) - 1),
    # Dummy case weights column, to get past initial verification:
    cwts = hardhat::importance_weights(NA)
  )

# ***** Case weights are created by a recipe step *****
# presence = 1, absence = n_presence / n_absence, computed from whatever data
# the recipe is given
lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea,
  data = sf::st_drop_geometry(lsl)
)
lsl_recipe <- recipes::step_mutate(
  lsl_recipe,
  cwts = hardhat::importance_weights(
    ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
  ),
  # The "case_weights" role has to be set explicitly:
  role = "case_weights"
)

# split into folds
lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

# try GLM
glm_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

# Weighted workflow: no add_formula(), because the formula is in the recipe
glm_wts_workflow <- workflow() |>
  add_recipe(lsl_recipe) |>
  add_model(glm_model) |>
  add_case_weights(cwts)
glm_wflow_wts <- fit_resamples(glm_wts_workflow, resamples = lsl_folds)

# per-fold ROC AUC as reported by fit_resamples()
glm_wflow_wts |>
  unnest(.metrics) |>
  filter(.metric == "roc_auc")
#> # A tibble: 10 × 7
#>    splits           id     .metric .estimator .estimate .config         .notes  
#>    <list>           <chr>  <chr>   <chr>          <dbl> <chr>           <list>  
#>  1 <split [315/35]> Fold01 roc_auc binary         0.933 Preprocessor1_… <tibble>
#>  2 <split [310/40]> Fold02 roc_auc binary         0.776 Preprocessor1_… <tibble>
#>  3 <split [299/51]> Fold03 roc_auc binary         0.897 Preprocessor1_… <tibble>
#>  4 <split [323/27]> Fold04 roc_auc binary         0.8   Preprocessor1_… <tibble>
#>  5 <split [324/26]> Fold05 roc_auc binary         0.817 Preprocessor1_… <tibble>
#>  6 <split [316/34]> Fold06 roc_auc binary         0.822 Preprocessor1_… <tibble>
#>  7 <split [314/36]> Fold07 roc_auc binary         0.744 Preprocessor1_… <tibble>
#>  8 <split [322/28]> Fold08 roc_auc binary         0.840 Preprocessor1_… <tibble>
#>  9 <split [306/44]> Fold09 roc_auc binary         0.72  Preprocessor1_… <tibble>
#> 10 <split [320/30]> Fold10 roc_auc binary         0.790 Preprocessor1_… <tibble>

# Compare with AUC values calculated on each fold separately: refit a plain
# stats::glm() on every analysis set, with the case weights recomputed from
# that analysis set, and score it on the matching assessment set.
# Iterating over the splits themselves avoids hard-coding the number of folds.
for (i in seq_along(glm_wflow_wts$splits)) {
  fold_split <- glm_wflow_wts$splits[[i]]

  fold_train <- analysis(fold_split) |>
    mutate(cwts = ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0)))
  fold_test <- assessment(fold_split)

  untidy_glm <- glm(
    lslpts ~ slope + cplan + cprof + elev + log10_carea,
    data = fold_train,
    family = "binomial",
    weights = cwts
  )

  fold_auc <- yardstick::roc_auc_vec(
    truth = fold_test$lslpts,
    estimate = predict(untidy_glm, newdata = fold_test, type = "response"),
    # the event of interest is the second factor level ("1")
    event_level = "second"
  )

  print(round(fold_auc, 3))
}
#> [1] 0.933
#> [1] 0.776
#> [1] 0.897
#> [1] 0.8
#> [1] 0.817
#> [1] 0.822
#> [1] 0.744
#> [1] 0.84
#> [1] 0.72
#> [1] 0.79

# now try the same with a GAM (mgcv engine, REML smoothness selection)
gam_model <- gen_additive_mod() |>
  set_engine("mgcv", method = "REML") |>
  set_mode("classification")

# Weighted workflow: the recipe supplies the variables and the case weights.
# The model formula is passed to add_model(); it has no smooth terms, so only
# linear terms are fitted.
gam_wts_workflow <- workflow() |>
  add_recipe(lsl_recipe) |>
  add_model(
    gam_model,
    formula = lslpts ~ slope + cplan + cprof + elev + log10_carea
  ) |>
  add_case_weights(cwts)
gam_wflow_wts <- fit_resamples(gam_wts_workflow, resamples = lsl_folds)

# per-fold ROC AUC as reported by fit_resamples()
gam_wflow_wts |>
  unnest(.metrics) |>
  filter(.metric == "roc_auc")
#> # A tibble: 10 × 7
#>    splits           id     .metric .estimator .estimate .config         .notes  
#>    <list>           <chr>  <chr>   <chr>          <dbl> <chr>           <list>  
#>  1 <split [315/35]> Fold01 roc_auc binary         0.933 Preprocessor1_… <tibble>
#>  2 <split [310/40]> Fold02 roc_auc binary         0.776 Preprocessor1_… <tibble>
#>  3 <split [299/51]> Fold03 roc_auc binary         0.897 Preprocessor1_… <tibble>
#>  4 <split [323/27]> Fold04 roc_auc binary         0.8   Preprocessor1_… <tibble>
#>  5 <split [324/26]> Fold05 roc_auc binary         0.817 Preprocessor1_… <tibble>
#>  6 <split [316/34]> Fold06 roc_auc binary         0.822 Preprocessor1_… <tibble>
#>  7 <split [314/36]> Fold07 roc_auc binary         0.744 Preprocessor1_… <tibble>
#>  8 <split [322/28]> Fold08 roc_auc binary         0.840 Preprocessor1_… <tibble>
#>  9 <split [306/44]> Fold09 roc_auc binary         0.72  Preprocessor1_… <tibble>
#> 10 <split [320/30]> Fold10 roc_auc binary         0.790 Preprocessor1_… <tibble>

# Calculate each fold separately with mgcv::gam(): refit on every analysis
# set, with the case weights recomputed from that analysis set, and score the
# fit on the matching assessment set.
#
# NOTE: an earlier version of this loop passed `untidy_glm` (the GLM left over
# from the last iteration of the GLM loop) to predict() instead of
# `untidy_gam`, so the per-fold GAM fits were never actually evaluated. Any
# output recorded below this loop before that fix reflects the stale GLM and
# needs to be regenerated.
for (i in seq_along(gam_wflow_wts$splits)) {
  fold_split <- gam_wflow_wts$splits[[i]]

  fold_train <- analysis(fold_split) |>
    mutate(cwts = ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0)))
  fold_test <- assessment(fold_split)

  # Namespaced call: none of the library() calls in this script attaches mgcv.
  untidy_gam <- mgcv::gam(
    lslpts ~ slope + cplan + cprof + elev + log10_carea,
    data = fold_train,
    family = "binomial",
    weights = cwts,
    method = "REML"
  )

  # predict.gam() may return a 1-d array; as.numeric() gives yardstick a
  # plain numeric vector.
  fold_prob <- as.numeric(
    predict(untidy_gam, newdata = fold_test, type = "response")
  )

  fold_auc <- yardstick::roc_auc_vec(
    truth = fold_test$lslpts,
    estimate = fold_prob,
    # the event of interest is the second factor level ("1")
    event_level = "second"
  )

  print(round(fold_auc, 3))
}
#> [1] 0.933
#> [1] 0.776
#> [1] 0.898
#> [1] 0.822
#> [1] 0.843
#> [1] 0.841
#> [1] 0.781
#> [1] 0.856
#> [1] 0.739
#> [1] 0.79

Created on 2023-05-25 with reprex v2.0.2