romainkp / stremr

Streamlined Estimation for Static, Dynamic and Stochastic Treatment Regimes in Longitudinal Data
MIT License

Bug in `QlearnModel` when ML-based regression fails #22

Open osofr opened 8 years ago

osofr commented 8 years ago

When doing Q-learning, especially with stratification, a single algorithm might fail. In that case, a fallback fit with either glm or speedglm should be performed. For some reason, the fallback is not triggered in this example (./tests/examples/2_building_blocks_example.R).

The following error is produced:

unable to run randomForest with h2o for: intercept only models or designmat with zero rows or  constant outcome (y) ...
Error in UseMethod("predictP1") : 
  no applicable method for 'predictP1' applied to an object of class "try-error"
In addition: Warning messages:
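
For reference, here is a minimal base-R sketch of what the error above suggests and of the kind of fallback described: the result of a failed `try()` has class "try-error", so passing it straight to an S3 generic fails at dispatch. None of the names below (`predict_p1`, `fit_ml`, `fit_glm`, `fit_with_fallback`) are stremr functions; they are hypothetical stand-ins, and this is only an assumption about the failure mode, not the actual `QlearnModel` code.

```r
# Hypothetical S3 generic standing in for a prediction method (not stremr's).
predict_p1 <- function(fit, ...) UseMethod("predict_p1")
predict_p1.glm <- function(fit, newdata, ...) {
  as.numeric(predict(fit, newdata = newdata, type = "response"))
}

# Hypothetical ML fitter that always fails, mimicking a backend that
# refuses a degenerate design matrix or a constant outcome.
fit_ml <- function(data) stop("constant outcome (y)")

# Hypothetical fallback fitter: plain logistic regression via stats::glm.
fit_glm <- function(data) glm(y ~ x, data = data, family = binomial())

# Try the ML fit first; if it errors, fall back to glm instead of
# handing the "try-error" object on to the prediction step.
fit_with_fallback <- function(data) {
  fit <- try(fit_ml(data), silent = TRUE)
  if (inherits(fit, "try-error")) {
    fit <- fit_glm(data)
  }
  fit
}

dat <- data.frame(
  y = c(0, 1, 0, 1, 1, 0, 1, 0),
  x = c(0.1, 0.9, 0.4, 0.8, 0.3, 0.6, 0.5, 0.2)
)

# Unguarded path: the failed fit is passed on and S3 dispatch fails,
# because no predict_p1 method exists for class "try-error".
bad_fit   <- try(fit_ml(dat), silent = TRUE)
unguarded <- try(predict_p1(bad_fit, newdata = dat), silent = TRUE)

# Guarded path: the fallback produces a usable glm fit.
fit <- fit_with_fallback(dat)
p   <- predict_p1(fit, newdata = dat)
```

In this sketch the guard sits immediately after the `try()` call, so the prediction step only ever sees a fitted model object.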

The code from ./tests/examples/2_building_blocks_example.R:

library(data.table)
library(stremr)
data(OdataNoCENS)
OdataDT <- as.data.table(OdataNoCENS, key = c("ID", "t"))
OdataDT[, ("N.tminus1") := shift(get("N"), n = 1L, type = "lag", fill = 1L), by = ID]
OdataDT[, ("TI.tminus1") := shift(get("TI"), n = 1L, type = "lag", fill = 1L), by = ID]
OdataDT[, ("TI.set1") := 1L]
OdataDT[, ("TI.set0") := 0L]
OData <- importData(OdataDT, ID = "ID", t = "t", covars = c("highA1c", "lastNat1", "N.tminus1"),
                    CENS = "C", TRT = "TI", MONITOR = "N", OUTCOME = "Y.tplus1")
gform_CENS <- "C ~ highA1c + lastNat1"
gform_TRT = "TI ~ CVD + highA1c + N.tminus1"
gform_MONITOR <- "N ~ 1"
stratify_CENS <- list(C=c("t < 16", "t == 16"))
require("h2o")
h2o::h2o.init(nthreads = -1)
params_TRT = list(fit.package = "h2o", fit.algorithm = "gbm", ntrees = 50,
    learn_rate = 0.05, sample_rate = 0.8, col_sample_rate = 0.8,
    balance_classes = TRUE)
params_CENS = list(fit.package = "speedglm", fit.algorithm = "glm")
params_MONITOR = list(fit.package = "speedglm", fit.algorithm = "glm")
OData <- fitPropensity(OData,
            gform_CENS = gform_CENS, stratify_CENS = stratify_CENS, params_CENS = params_CENS,
            gform_TRT = gform_TRT, params_TRT = params_TRT,
            gform_MONITOR = gform_MONITOR, params_MONITOR = params_MONITOR)
t.surv <- c(0:5)
Qforms <- rep.int("Q.kplus1 ~ CVD + highA1c + N + lastNat1 + TI + TI.tminus1", (max(t.surv)+1))
params_Q = list(fit.package = "h2o", fit.algorithm = "randomForest",
                ntrees = 100, learn_rate = 0.05, sample_rate = 0.8,
                col_sample_rate = 0.8, balance_classes = TRUE)
tmle_est <- fitTMLE(OData, t_periods = t.surv, intervened_TRT = "TI.set1",
            Qforms = Qforms, params_Q = params_Q,
            stratifyQ_by_rule = TRUE)