tidymodels / censored

Parsnip wrappers for survival models
https://censored.tidymodels.org/

`coxnet_train()` struggles with large number of predictors #311

Open hfrick opened 9 months ago

hfrick commented 9 months ago

This originally surfaced in https://community.rstudio.com/t/error-with-tuning-for-censored-regression/181248 and here is a more minimal reprex.

library(censored)
#> Loading required package: parsnip
#> Loading required package: survival

set.seed(1)
lung <- na.omit(lung)
surv <- Surv(lung$time[1:125], lung$status[1:125])
predictors <- matrix(runif(n = 125 * 22215), nrow = 125)

dat <- cbind(
  data.frame(surv = surv),
  as.data.frame(predictors)
)

# pure glmnet model works
glmnet_fit <- glmnet::glmnet(x = as.matrix(dat[, -1]), y = dat$surv, family = "cox")

# single parsnip model fails
parsnip_fit <- fit(proportional_hazards(engine = "glmnet", penalty = 0.1), surv ~ ., data = dat)
#> Error: protect(): protection stack overflow

# reducing the number of predictors makes it work again
parsnip_fit <- fit(proportional_hazards(engine = "glmnet", penalty = 0.1), surv ~ ., data = dat[, 1:16500])

Created on 2024-02-03 with reprex v2.1.0

The choice of parsnip interface, fit() vs. fit_xy(), does not matter here because both go through censored::coxnet_train().

# 12: terms.formula(formula, specials = "strata", data = data)
# 11: stats::terms(formula, specials = "strata", data = data)
# 10: has_strata(formula, data)
# 9: remove_strata(formula, data, call = call)
# 8: censored::coxnet_train(formula = surv ~ ., data = data)
# 7: eval_tidy(e, env = envir, ...)
# 6: eval_mod(fit_call, capture = control$verbosity == 0, catch = control$catch,
#             envir = env, ...)
# 5: form_form(object = object, control = control, env = eval_env)
# 4: fit.model_spec(proportional_hazards(engine = "glmnet", penalty = 0.1),
#                   surv ~ ., data = coxdata)
# 3: NextMethod()
# 2: fit.proportional_hazards(proportional_hazards(engine = "glmnet",
#                                                  penalty = 0.1), surv ~ ., data = coxdata)
# 1: fit(proportional_hazards(engine = "glmnet", penalty = 0.1), surv ~
#          ., data = coxdata)
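
To see what frames 11 and 12 do, the same kind of terms() call can be run on a tiny data frame. This is only an illustrative sketch in base R: the data frame `small` and its columns are made up, and it does not reproduce the error. It shows that with `data` supplied, the dot in the formula is expanded into one term per remaining column, so the size of the terms object grows with the number of predictors.

```r
# Hypothetical toy data, not from the reprex above
small <- data.frame(y = 1:3, a = 1:3, b = 1:3)

# Same arguments as in the traceback: specials = "strata" and data supplied
tt <- stats::terms(y ~ ., specials = "strata", data = small)

# The dot has been expanded to every column that is not on the left-hand side
attr(tt, "term.labels")

# No strata() term was found, so this entry is NULL
attr(tt, "specials")$strata
```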

terms.formula() fails at this call:

terms <- .External(C_termsform, x, specials, data, keep.order, allowDotAsName)
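
One possible direction, sketched here purely as an assumption and not as what censored does or will do, is to check for strata() by inspecting the symbols of the unexpanded formula, which avoids calling terms() with the wide data. `has_strata_call()` is a hypothetical helper name.

```r
# Hypothetical helper, not part of censored: look for the symbol `strata`
# in the formula itself, so the dot is never expanded against the data.
# Caveat: this also returns TRUE for a predictor that is literally named
# `strata`, so it is a rough check only.
has_strata_call <- function(formula) {
  "strata" %in% all.names(formula)
}

no_strata   <- has_strata_call(surv ~ .)
with_strata <- has_strata_call(surv ~ . + strata(inst))
```

Because all.names() only walks the formula expression, its cost depends on the length of the formula as written, not on the number of columns in the data.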