Open bcjaeger opened 2 months ago
Hi @bcjaeger, thank you! Oblique forests look very interesting!
It's a good question. causal_survival_forest has many complicated nuisance components, and for simplicity we opted not to accept user-specified estimates. One route could be to add a specialized causal_survival_forest.fit entry point.
But before going down that road, how about stitching together your own causal_survival_forest and trying that out first? Here is a template for a causal_survival_forest.custom that you can copy directly into an R session and experiment with (using ::: to access grf's private methods).
I added 5 TODO comments where you could replace grf's survival-based estimates with your own estimates. Please let me know if any of these are unclear.
Awesome idea. I checked out the TODO steps and everything looked very clear. I will try this soon and let you know how it goes. Thank you!
I've made a little progress. The reprex below runs with the development version of aorsf, but will not run with the current version on CRAN. I made a small update to allow out-of-bag predictions on modified versions of the training data. I fiddled with some real data and encountered some efficiency issues with large Y.grid, so I put in two additions (I added comments that start with 'bcj addition') that were not part of the TODOs you left.
causal_survival_forest.custom <- function(
    X, Y, W, D,
    W.hat = NULL,
    target = c("RMST", "survival.probability"),
    horizon = NULL,
    failure.times = NULL,
    num.trees = 2000,
    sample.weights = NULL,
    clusters = NULL,
    equalize.cluster.weights = FALSE,
    sample.fraction = 0.5,
    mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
    min.node.size = 5,
    honesty = TRUE,
    honesty.fraction = 0.5,
    honesty.prune.leaves = TRUE,
    alpha = 0.05,
    imbalance.penalty = 0,
    stabilize.splits = TRUE,
    ci.group.size = 2,
    tune.parameters = "none",
    compute.oob.predictions = TRUE,
    num.threads = NULL,
    seed = runif(1, 0, .Machine$integer.max)) {
  target <- match.arg(target)
  if (is.null(horizon) || !is.numeric(horizon) || length(horizon) != 1) {
    stop("The `horizon` argument defining the estimand is required.")
  }
  has.missing.values <- grf:::validate_X(X, allow.na = TRUE)
  grf:::validate_sample_weights(sample.weights, X)
  Y <- grf:::validate_observations(Y, X)
  W <- grf:::validate_observations(W, X)
  D <- grf:::validate_observations(D, X)
  clusters <- grf:::validate_clusters(clusters, X)
  samples.per.cluster <- grf:::validate_equalize_cluster_weights(equalize.cluster.weights, clusters, sample.weights)
  num.threads <- grf:::validate_num_threads(num.threads)
  if (any(Y < 0)) {
    stop("The event times must be non-negative.")
  }
  if (!all(D %in% c(0, 1))) {
    stop("The censor values can only be 0 or 1.")
  }
  if (sum(D) == 0) {
    stop("All observations are censored.")
  }
  if (target == "RMST") {
    # f(T) <- min(T, horizon)
    D[Y >= horizon] <- 1
    Y[Y >= horizon] <- horizon
    fY <- Y
  } else {
    # f(T) <- 1{T > horizon}
    fY <- as.numeric(Y > horizon)
  }
  if (is.null(failure.times)) {
    Y.grid <- sort(unique(Y))
    # bcj addition 1
    # large Y.grid can slow computation down.
    # consider simplifying if Y.grid is > 100 points
    if(length(Y.grid) > 100){
      Y.grid <- seq(min(Y.grid), max(Y.grid), length.out = 100)
    }
  } else if (min(Y) < min(failure.times)) {
    stop("If provided, `failure.times` should be a grid starting on or before min(Y).")
  } else {
    # bcj addition 2
    # consider simplifying for computational efficiency
    if(length(failure.times) > 100){
      failure.times.orig <- failure.times
      failure.times <- seq(min(failure.times),
                           max(failure.times),
                           length.out = 100)
      # make the subsetted failure times be a proper subset of the original
      # failure times. It would be more efficient to sample the original,
      # but this helps to ensure the grid is more evenly spaced.
      for(i in seq_along(failure.times)){
        closest_index <- which.min(abs(failure.times[i] - failure.times.orig))
        failure.times[i] <- failure.times.orig[closest_index]
      }
      # in case there were duplicates introduced.
      failure.times <- unique(failure.times)
    }
    Y.grid <- failure.times
  }
  if (length(Y.grid) <= 2) {
    stop("The number of distinct event times should be more than 2.")
  }
  if (horizon < min(Y.grid)) {
    stop("`horizon` cannot be before the first event.")
  }
  if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
    warning(paste0("The number of events are more than 10% of the sample size. ",
                   "To reduce the computational burden of fitting survival and ",
                   "censoring curves, consider discretizing the event values `Y` or ",
                   "supplying a coarser grid with the `failure.times` argument. "), immediate. = TRUE)
  }
  if (is.null(W.hat)) {
    forest.W <- grf::regression_forest(X, W, num.trees = max(50, num.trees / 4),
                                       sample.weights = sample.weights, clusters = clusters,
                                       equalize.cluster.weights = equalize.cluster.weights,
                                       sample.fraction = sample.fraction, mtry = mtry,
                                       min.node.size = 5, honesty = TRUE,
                                       honesty.fraction = 0.5, honesty.prune.leaves = TRUE,
                                       alpha = alpha, imbalance.penalty = imbalance.penalty,
                                       ci.group.size = 1, tune.parameters = tune.parameters,
                                       compute.oob.predictions = TRUE,
                                       num.threads = num.threads, seed = seed)
    W.hat <- predict(forest.W)$predictions
  } else if (length(W.hat) == 1) {
    W.hat <- rep(W.hat, nrow(X))
  } else if (length(W.hat) != nrow(X)) {
    stop("W.hat has incorrect length.")
  }
  W.centered <- W - W.hat
  args.nuisance <- list(failure.times = failure.times,
                        num.trees = max(50, min(num.trees / 4, 500)),
                        sample.weights = sample.weights,
                        clusters = clusters,
                        equalize.cluster.weights = equalize.cluster.weights,
                        sample.fraction = sample.fraction,
                        mtry = mtry,
                        min.node.size = 15,
                        honesty = TRUE,
                        honesty.fraction = 0.5,
                        honesty.prune.leaves = TRUE,
                        alpha = alpha,
                        prediction.type = "Nelson-Aalen", # to guarantee non-zero estimates.
                        compute.oob.predictions = TRUE,
                        num.threads = num.threads,
                        seed = seed)
  # Compute survival-based nuisance components (https://arxiv.org/abs/2001.09887)
  # m(x) relies on the survival function conditional on only X, while Q(x) relies on the conditioning (X, W).
  # Instead of fitting two separate survival forests, we can use the forest fit on (X, W) to compute m(X)
  # using the identity
  # E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
  # (for this to work W has to be binary).
  # orsf would throw an error if columns were unnamed
  if(is.null(colnames(X))) colnames(X) <- paste("x", seq(ncol(X)), sep = "_")
  orsf_data <- as.data.frame(cbind(y = Y, d = D, w = W, X))
  # this is to prevent aorsf from throwing an error when it
  # encounters event times of 0. I should remove this assertion
  # from aorsf, but for now:
  orsf_data$y <- pmax(orsf_data$y, .Machine$double.eps)
  # default is to use unique event times
  if(is.null(args.nuisance$failure.times)){
    args.nuisance$failure.times <- sort(unique(Y[D==1]))
  }
  # plugging in inputs from args.nuisance where possible
  sf.survival <- aorsf::orsf(
    data = orsf_data,
    formula = y + d ~ .,
    n_tree = args.nuisance$num.trees,
    weights = args.nuisance$sample.weights,
    sample_fraction = args.nuisance$sample.fraction,
    mtry = args.nuisance$mtry,
    # using min node size for both leaf stopping criteria.
    # This will usually lead to more shallow trees.
    leaf_min_obs = args.nuisance$min.node.size,
    leaf_min_events = args.nuisance$min.node.size,
    oobag_pred_type = "surv",
    oobag_pred_horizon = args.nuisance$failure.times,
    tree_seeds = round(args.nuisance$seed)
  )
  binary.W <- all(W %in% c(0, 1))
  if (binary.W) {
    # The survival function conditioning on being treated S(t, x, 1) estimated with an "S-learner".
    # Computing OOB estimates for modified training samples is not a workflow we have implemented,
    # so we do it with a manual workaround here (deleting/re-inserting precomputed predictions)
    orsf_data$w <- 1
    S1.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- 0
    S0.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- W
    if (target == "RMST") {
      Y.hat <- W.hat * grf:::expected_survival(S1.hat, sf.survival$pred_horizon) +
        (1 - W.hat) * grf:::expected_survival(S0.hat, sf.survival$pred_horizon)
    } else {
      horizonS.index <- findInterval(horizon, sf.survival$pred_horizon)
      if (horizonS.index == 0) {
        Y.hat <- rep(1, nrow(X))
      } else {
        Y.hat <- W.hat * S1.hat[, horizonS.index] + (1 - W.hat) * S0.hat[, horizonS.index]
      }
    }
  } else {
    # Ignoring this code branch for simplicity's sake
    stop("Custom survival models + continuous treatment not implemented")
    # If continuous W fit a separate survival forest to estimate E[f(T) | X].
    # sf.Y <- do.call(grf::survival_forest, c(list(X = X, Y = Y, D = D), args.nuisance))
    # SY.hat <- predict(sf.Y)$predictions
    # if (target == "RMST") {
    #   Y.hat <- expected_survival(SY.hat, sf.Y$failure.times)
    # } else {
    #   horizonS.index <- findInterval(horizon, sf.survival$failure.times)
    #   if (horizonS.index == 0) {
    #     Y.hat <- rep(1, nrow(X))
    #   } else {
    #     Y.hat <- SY.hat[, horizonS.index]
    #   }
    # }
  }
  # The conditional survival function S(t, x, w) used to construct Q(x).
  S.hat <- predict(sf.survival, oobag = TRUE, pred_horizon = Y.grid)
  if (!identical(dim(S.hat), c(length(Y), length(Y.grid)))) stop("Wrong S.hat prediction dims")
  # The conditional survival function for the censoring process S_C(t, x, w).
  orsf_data$d <- 1 - D
  sf.censor <- aorsf::orsf_update(sf.survival,
                                  data = orsf_data,
                                  # default split_min_stat is about 3,
                                  # setting to 10 makes trees more shallow
                                  split_min_stat = 10,
                                  oobag_pred_horizon = Y.grid)
  C.hat <- sf.censor$pred_oobag
  if (!identical(dim(C.hat), c(length(Y), length(Y.grid)))) stop("Wrong C.hat prediction dims")
  if (target == "survival.probability") {
    # Evaluate psi up to horizon
    D[Y > horizon] <- 1
    Y[Y > horizon] <- horizon
  }
  Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
  C.Y.hat <- C.hat[cbind(seq_along(Y.index), Y.index)] # Pick out P[Ci > Yi | Xi, Wi]
  if (target == "RMST" && any(C.Y.hat <= 0.05)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M).",
                  "This warning appears when M is less than 0.05, at which point causal survival forest",
                  "can not be expected to deliver reliable estimates."), immediate. = TRUE)
  } else if (target == "RMST" && any(C.Y.hat < 0.2)) {
    warning(paste("Estimated censoring probabilities are lower than 0.2",
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M)."))
  } else if (target == "survival.probability" && any(C.Y.hat <= 0.001)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- forest estimates will likely be very unstable, a smaller target `horizon`",
                  "is recommended."), immediate. = TRUE)
  } else if (target == "survival.probability" && any(C.Y.hat < 0.05)) {
    warning(paste("Estimated censoring probabilities are lower than 0.05",
                  "and forest estimates may not be stable. Using a smaller target `horizon`",
                  "may help."))
  }
  psi <- grf:::compute_psi(S.hat, C.hat, C.Y.hat, Y.hat, W.centered,
                           D, fY, Y.index, Y.grid, target, horizon)
  grf:::validate_observations(psi[["numerator"]], X)
  grf:::validate_observations(psi[["denominator"]], X)
  data <- grf:::create_train_matrices(X,
                                      treatment = W.centered,
                                      survival.numerator = psi[["numerator"]],
                                      survival.denominator = psi[["denominator"]],
                                      censor = D,
                                      sample.weights = sample.weights)
  args <- list(num.trees = num.trees,
               clusters = clusters,
               samples.per.cluster = samples.per.cluster,
               sample.fraction = sample.fraction,
               mtry = mtry,
               min.node.size = min.node.size,
               honesty = honesty,
               honesty.fraction = honesty.fraction,
               honesty.prune.leaves = honesty.prune.leaves,
               alpha = alpha,
               imbalance.penalty = imbalance.penalty,
               stabilize.splits = stabilize.splits,
               ci.group.size = ci.group.size,
               compute.oob.predictions = compute.oob.predictions,
               num.threads = num.threads,
               seed = seed)
  forest <- grf:::do.call.rcpp(grf:::causal_survival_train, c(data, args))
  class(forest) <- c("causal_survival_forest", "grf")
  forest[["seed"]] <- seed
  forest[["_psi"]] <- psi
  forest[["X.orig"]] <- X
  forest[["Y.orig"]] <- Y
  forest[["W.orig"]] <- W
  forest[["D.orig"]] <- D
  forest[["Y.hat"]] <- Y.hat
  forest[["W.hat"]] <- W.hat
  forest[["sample.weights"]] <- sample.weights
  forest[["clusters"]] <- clusters
  forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
  forest[["has.missing.values"]] <- has.missing.values
  forest[["target"]] <- target
  forest[["horizon"]] <- horizon
  forest
}
n <- 500
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
horizon <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, horizon)
censor.time <- 2 * runif(n)
Y <- round(pmin(failure.time, censor.time), 2)
D <- as.integer(failure.time <= censor.time)
# grf causal survival forest
csf.orig <- grf::causal_survival_forest(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.orig)
#> estimate std.err
#> 0.59288638 0.02277714
head(predict(csf.orig))
#> predictions
#> 1 0.5627453
#> 2 0.4707053
#> 3 0.6011034
#> 4 0.6892941
#> 5 0.5707982
#> 6 0.5513207
# your custom CS forest
csf.custom <- causal_survival_forest.custom(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.custom)
#> estimate std.err
#> 0.59504669 0.02245838
head(predict(csf.custom))
#> predictions
#> 1 0.5710990
#> 2 0.4703588
#> 3 0.6050363
#> 4 0.6909468
#> 5 0.5722461
#> 6 0.5575252
Created on 2024-04-30 with reprex v2.1.0
Very cool! Yes, the Y.grid shouldn't be too large. We emit a warning with suggestions if that is the case, but it could be nice further down the line to add some automated grid selection. (PS: in case it's useful to keep in mind when doing your modifications, the CSF code expects both S.hat and C.hat to be indexed by the same time grid.)
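One way the automated grid selection mentioned here might look is sketched below in base R. This is only an illustration: coarsen_grid is a hypothetical helper, not a function in grf or aorsf, and the default cap of 100 points simply mirrors the 'bcj addition' comments in the reprex. It snaps an evenly spaced grid onto the nearest observed times, so the result stays a subset of the original times, and the same vector can then be used wherever S.hat and C.hat are predicted so that both share one time grid.

```r
# Hypothetical helper (not part of grf or aorsf): thin a time grid to at most
# `max.points` values while keeping every retained value an observed time.
coarsen_grid <- function(times, max.points = 100) {
  times <- sort(unique(times))
  if (length(times) <= max.points) {
    return(times)
  }
  # Evenly spaced targets spanning the observed range.
  target <- seq(min(times), max(times), length.out = max.points)
  # Index of the largest observed time <= each target (guarded to be >= 1).
  lower <- pmax(findInterval(target, times), 1L)
  upper <- pmin(lower + 1L, length(times))
  # Snap each target to whichever neighbour is closer; ties go to `lower`.
  use.upper <- abs(times[upper] - target) < abs(target - times[lower])
  pick <- ifelse(use.upper, upper, lower)
  # Snapping can map two targets to the same time, so drop duplicates.
  unique(times[pick])
}

# Toy event times, for illustration only.
set.seed(1)
event.times <- cumsum(rexp(500))
Y.grid <- coarsen_grid(event.times)
length(Y.grid)  # no more than max.points
```

Compared with the loop in 'bcj addition 2', this version uses findInterval to locate the neighbouring observed times in one vectorized pass, which should matter only when the original grid is very long.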
Hello,
Thank you for developing grf. It's great! Would it be feasible to allow Y.hat to be an input for causal survival forests? I see from the code in causal_survival_forest that a couple of intermediate values are taken from the output of survival_forest(), so it may not be straightforward to just plug in Y.hat from a separate routine. I am interested in using a forest object from aorsf::orsf() (https://github.com/ropensci/aorsf). If aorsf::orsf() could provide those intermediate outputs, would it be feasible to allow a forest object from aorsf to be used here?
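To make the intermediate outputs concrete: for the RMST target with a binary W, the reprex in this thread needs survival curves on a time grid under W = 1 and W = 0, and combines them through E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]. The base R sketch below illustrates that arithmetic on toy numbers. rmst_from_curve is a hypothetical helper, not grf's internal implementation; it assumes each curve is a right-continuous step function that equals 1 before the first grid point, and that the last grid point is the horizon.

```r
# Hypothetical helper: integrate step-function survival curves from 0 to the
# last grid point. S.hat is an n x K matrix; column k holds S(grid[k] | x).
rmst_from_curve <- function(S.hat, grid) {
  stopifnot(ncol(S.hat) == length(grid), !is.unsorted(grid), grid[1] >= 0)
  # Widths of the K intervals [0, grid[1]), [grid[1], grid[2]), ...
  widths <- diff(c(0, grid))
  # Curve height on each interval: 1 before the first grid point, then S(grid[k]).
  heights <- cbind(1, S.hat[, -ncol(S.hat), drop = FALSE])
  drop(heights %*% widths)
}

grid <- c(1, 2, 3)                   # toy time grid, horizon = 3
S1.hat <- rbind(c(1.0, 1.0, 1.0),    # toy curves under W = 1
                c(0.8, 0.6, 0.4))
S0.hat <- rbind(c(0.5, 0.25, 0.1),   # toy curves under W = 0
                c(0.5, 0.25, 0.1))
W.hat <- c(0.5, 0.5)                 # toy propensity estimates e(X)

# E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
Y.hat <- W.hat * rmst_from_curve(S1.hat, grid) +
  (1 - W.hat) * rmst_from_curve(S0.hat, grid)
Y.hat
```

Under these assumptions, any survival model that can return an n x K matrix of survival probabilities on a known grid would supply what this step needs.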