sahirbhatnagar / casebase

http://sahirbhatnagar.com/casebase/

Specify foldid in fitSmoothHazard.fit #155

Open karinakwan opened 2 years ago

karinakwan commented 2 years ago

I'd like to specify the cross-validation folds for my data. In cv.glmnet this works without issue: I pass foldid a vector of integers from 1 to the number of folds. However, passing the same foldid argument to fitSmoothHazard.fit produces the error shown below.

library(casebase)
#> See example usage at http://sahirbhatnagar.com/casebase/

data(support)
# Change time to years
support$d.time <- support$d.time/365.25
# Split into test and train
train_index <- sample(nrow(support), 0.95*nrow(support))
test_index <- setdiff(1:nrow(support), train_index)
train <- support[train_index,]
test <- support[test_index,]

# Create matrices for inputs
x <- model.matrix(death ~ . - d.time - aps - sps, 
                  data = train)[, -c(1)] # Remove intercept
y <- data.matrix(subset(train, select = c(d.time, death)))

# Divide into 10 folds 
foldids <- sample(1:10, size = nrow(y), replace = TRUE)

# Regularized logistic regression to estimate hazard
fitSmoothHazard.fit(x, y, family = "glmnet", time = "d.time", event = "death",
                    formula_time = ~ log(d.time),  alpha = 1, ratio = 10, 
                    standardize = TRUE, penalty.factor = c(0, rep(1, ncol(x))),
                    foldid = foldids)
#> Warning in stop(err, call. = FALSE): additional arguments ignored in stop()
#> Error in tapply(weights, foldid, sum): arguments must have same length

Created on 2022-05-20 by the reprex package (v2.0.1)

karinakwan commented 2 years ago

Update with cv.glmnet examples for binomial and Cox

library(casebase)
#> See example usage at http://sahirbhatnagar.com/casebase/
library(glmnet)
#> Loading required package: Matrix
#> Loaded glmnet 4.1-3
#> 
#> Attaching package: 'glmnet'
#> The following object is masked from 'package:casebase':
#> 
#>     prepareX

data(support)
# Change time to years
support$d.time <- support$d.time/365.25
# Split into test and train
train_index <- sample(nrow(support), 0.95*nrow(support))
test_index <- setdiff(1:nrow(support), train_index)
train <- support[train_index,]
test <- support[test_index,]

# Create matrices for inputs
x <- model.matrix(death ~ . - d.time - aps - sps, 
                  data = train)[, -c(1)] # Remove intercept
y <- data.matrix(subset(train, select = c(d.time, death)))

# Divide into 10 folds
foldids <- sample(1:10, size = nrow(y), replace = TRUE)

# Regularized logistic regression to estimate hazard
fitSmoothHazard.fit(x, y, family = "glmnet", time = "d.time", event = "death",
                    formula_time = ~ log(d.time),  alpha = 1, ratio = 10, 
                    standardize = TRUE, penalty.factor = c(0, rep(1, ncol(x))),
                    foldid = foldids)
#> Warning in stop(err, call. = FALSE): additional arguments ignored in stop()
#> Error in tapply(weights, foldid, sum): arguments must have same length

# Works for cv.glmnet binomial and Cox
cv.glmnet(x, y, family = "binomial", foldid = foldids)
#> 
#> Call:  cv.glmnet(x = x, y = y, foldid = foldids, family = "binomial") 
#> 
#> Measure: Binomial Deviance 
#> 
#>       Lambda Index Measure       SE Nonzero
#> min 0.000897    52  0.5839 0.008086      37
#> 1se 0.006329    31  0.5908 0.007435      27

colnames(y) <- c("time", "status")
cv.glmnet(x, y, family = "cox", foldid = foldids)
#> 
#> Call:  cv.glmnet(x = x, y = y, foldid = foldids, family = "cox") 
#> 
#> Measure: Partial Likelihood Deviance 
#> 
#>      Lambda Index Measure      SE Nonzero
#> min 0.00303    46   10.81 0.07141      37
#> 1se 0.03735    19   10.88 0.06871      14

Created on 2022-05-20 by the reprex package (v2.0.1)

turgeonmaxime commented 2 years ago

@karinakwan Thanks for reporting this issue. You get this error in casebase but not in glmnet because we apply cv.glmnet to the augmented dataset, i.e. the dataset after case-base sampling. If c is the number of events, the augmented dataset has c + c*ratio rows, so foldid must have that length to avoid the error.
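
As a rough sketch of that arithmetic, the base R snippet below builds a fold vector of length c + c*ratio. The helper `make_augmented_foldid` is hypothetical (it is not part of casebase), and it assumes the augmented dataset has exactly c + c*ratio rows, as described above. The row order of the augmented data is internal to casebase, so these folds are not tied to specific subjects.

```r
# Hypothetical helper (not part of casebase): build a balanced fold vector
# whose length matches the augmented dataset, assumed to be c + c * ratio rows.
make_augmented_foldid <- function(n_events, ratio, nfolds = 10) {
  n_aug <- n_events + n_events * ratio
  # Repeat 1..nfolds up to the required length, then shuffle the assignments
  sample(rep(seq_len(nfolds), length.out = n_aug))
}

# Toy event indicator standing in for y[, "death"]
status <- c(1, 0, 1, 1, 0, 1, 0, 1)
n_events <- sum(status == 1)   # c = 5
foldids_aug <- make_augmented_foldid(n_events, ratio = 10, nfolds = 5)
length(foldids_aug)            # 5 + 5 * 10 = 55
table(foldids_aug)             # 11 rows in each of the 5 folds
```

With the data from the report, the analogous call would be `make_augmented_foldid(sum(y[, "death"]), ratio = 10)`. Based on the explanation above, passing the result as foldid should avoid the length error, but this has not been verified here.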

At the very least, I think fitSmoothHazard.fit should output a more informative error message. Beyond that, I'm not sure we can do anything more: I don't see a meaningful way of internally "fixing" a foldid of length n without skewing the proportions of each fold.
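
One possible shape for a more informative error message is sketched below in base R. The function `check_foldid_length` and its arguments are hypothetical and are not taken from the casebase source. It assumes the number of events and the ratio are known at the point where foldid is validated.

```r
# Hypothetical check (not casebase code): fail early with a clear message
# when foldid does not match the number of rows in the augmented dataset.
check_foldid_length <- function(foldid, n_events, ratio) {
  n_aug <- n_events + n_events * ratio
  if (length(foldid) != n_aug) {
    stop(sprintf(
      paste0("`foldid` has length %s, but cross-validation runs on the ",
             "case-base sampled dataset, which has %s rows ",
             "(%s events + %s * %s sampled rows)."),
      length(foldid), n_aug, n_events, n_events, ratio
    ), call. = FALSE)
  }
  invisible(TRUE)
}

# A foldid of length n = 8 is rejected when c = 5 and ratio = 10
msg <- tryCatch(
  check_foldid_length(rep(1:4, length.out = 8), n_events = 5, ratio = 10),
  error = function(e) conditionMessage(e)
)
print(msg)
```

A check like this only reports the mismatch. It does not try to rescale a foldid of length n, in line with the concern about skewing fold proportions.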