ecpolley / SuperLearner

Current version of the SuperLearner R package

stratifyCV by iid #139

Status: Open. sgruber65 opened this issue 3 years ago.

sgruber65 commented 3 years ago

Hi. I'm sharing a little function (adapted from CVFolds()) that approximately preserves class balance within each training and validation set, even when ids are non-unique, as long as the minority class is rare. -- Susan

```r
stratifyCVFoldsById <- function(V, Y, id = NULL) {
  # 1. Distribute the ids that have Y = 1 in any of their rows equally among all the folds.
  # 2. Separately, distribute the ids that have Y = 0 in all of their rows equally among the folds.
  if (is.null(id)) id <- seq_along(Y)
  # Number of cases per id: one entry per unique id, sorted by id
  case_status_by_id <- tapply(Y, id, sum)
  case_ids <- names(case_status_by_id)[case_status_by_id > 0]
  noncase_ids <- names(case_status_by_id)[case_status_by_id == 0]
  if (V > min(length(case_ids), length(noncase_ids))) {
    stop("number of ids in the minority class is less than the number of folds")
  }
  # Shuffle each group of ids, then deal them round-robin into V folds
  valSet.case_ids <- split(sample(case_ids), rep(seq_len(V), length.out = length(case_ids)))
  valSet.noncase_ids <- split(sample(noncase_ids), rep(seq_len(V), length.out = length(noncase_ids)))
  validRows <- vector("list", length = V)
  names(validRows) <- paste(seq_len(V))
  for (v in seq_len(V)) {
    # Validation rows for fold v: every row whose id was assigned to fold v
    validRows[[v]] <- which(as.character(id) %in% c(valSet.case_ids[[v]], valSet.noncase_ids[[v]]))
  }
  validRows
}
```
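A quick way to confirm the two properties this function is after (each id lands in exactly one validation fold, and case ids are spread across the folds) is a small checker. `checkIdFolds()` is a hypothetical helper, not part of SuperLearner; it's a sketch that assumes `validRows` is a list of row-index vectors like the one stratifyCVFoldsById() builds, and that Y is coded 0/1. The toy data and folds are made up for illustration.

```r
# Hypothetical helper (not part of SuperLearner): sanity-check a list of
# validation-row vectors against the outcome Y (0/1) and the id vector.
checkIdFolds <- function(validRows, Y, id) {
  N <- length(Y)
  all_rows <- unlist(validRows, use.names = FALSE)
  # (a) every row appears in exactly one validation set
  stopifnot(length(all_rows) == N, setequal(all_rows, seq_len(N)))
  # (b) no id is split across two folds
  fold_of_row <- integer(N)
  for (v in seq_along(validRows)) fold_of_row[validRows[[v]]] <- v
  folds_per_id <- tapply(fold_of_row, id, function(f) length(unique(f)))
  stopifnot(all(folds_per_id == 1))
  # (c) share of ids in each validation fold that are cases (any Y == 1)
  vapply(validRows, function(rows) {
    case_by_id <- tapply(Y[rows], id[rows], max)
    mean(case_by_id > 0)
  }, numeric(1))
}

# Made-up example: 10 rows from 8 ids; ids 1 and 3 are cases.
id <- c(1, 1, 2, 3, 3, 4, 5, 6, 7, 8)
Y  <- c(0, 1, 0, 1, 1, 0, 0, 0, 0, 0)
validRows <- list(`1` = c(1, 2, 3, 6, 7), `2` = c(4, 5, 8, 9, 10))
checkIdFolds(validRows, Y, id)  # -> 0.25 in each fold
```

Splitting id 1's two rows across folds (e.g. rows 1 and 2 in different validation sets) makes check (b) fail with an error.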


ledell commented 2 years ago

@sgruber65 @ecpolley I have a function that (I think) does the same thing. Mine seems a lot more complex than the one above; however, it's a drop-in replacement for CVFolds(), so perhaps it can be useful? Or perhaps SuperLearner::CVFolds() has been updated in the meantime to support stratification by both ID and outcome?

Here's my version:

@ecpolley I would be happy to make a PR if useful.
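For simply using id-stratified folds in a fit, patching CVFolds() may not be needed: as I understand the package documentation, SuperLearner.CV.control() accepts a `validRows` argument (a list of validation-row vectors, one per fold) that SuperLearner() uses instead of building its own split. The sketch below assumes that behavior. The data are made up, and the compact round-robin assignment of whole ids to folds restates the stratify-by-id idea from this thread; it is not ledell's version. The fit only runs if SuperLearner is installed.

```r
# Made-up data: 60 rows from 30 ids (2 rows per id); 6 ids are cases.
set.seed(139)
n_id <- 30
id <- rep(seq_len(n_id), each = 2)
case_ids <- sample(seq_len(n_id), 6)
Y <- as.integer(id %in% case_ids)
X <- data.frame(x1 = rnorm(length(Y)), x2 = rnorm(length(Y)))
V <- 3L

# Assign whole ids to folds round-robin, cases and non-cases separately,
# so every fold gets a share of the case ids and no id straddles two folds.
fold_of_id <- integer(n_id)
fold_of_id[case_ids] <- rep(seq_len(V), length.out = length(case_ids))
non_case <- setdiff(seq_len(n_id), case_ids)
fold_of_id[non_case] <- rep(seq_len(V), length.out = length(non_case))
validRows <- split(seq_along(Y), fold_of_id[id])

# Pass the precomputed folds through cvControl (assumes the validRows argument
# of SuperLearner.CV.control() described above).
if (requireNamespace("SuperLearner", quietly = TRUE)) {
  library(SuperLearner)
  fit <- SuperLearner(
    Y = Y, X = X, family = binomial(), id = id,
    SL.library = c("SL.mean", "SL.glm"),
    cvControl = list(V = length(validRows), validRows = validRows)
  )
  print(fit)
}
```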