grf-labs / grf

Generalized Random Forests
https://grf-labs.github.io/grf/
GNU General Public License v3.0

Allow `Y.hat` input for `causal_survival_forest()`? #1410

Open bcjaeger opened 7 months ago

bcjaeger commented 7 months ago

Hello,

Thank you for developing grf. It's great!

Would it be feasible to allow Y.hat to be an input for causal survival forests? I see from the code in causal_survival_forest() that a couple of intermediate values are taken from the output of survival_forest(), so it may not be straightforward to just plug in Y.hat from a separate routine. I am interested in using a forest object from aorsf::orsf() (https://github.com/ropensci/aorsf). If aorsf::orsf() could provide those intermediate outputs, would it be feasible to allow a forest object from aorsf to be used here?

erikcs commented 7 months ago

Hi @bcjaeger, thank you! Oblique forests look very interesting!

It's a good question. causal_survival_forest has many complicated nuisance components, and for simplicity we opted not to support user-specified estimates. One route could be to add a specialized causal_survival_forest.fit entry point.

But before going down that road, how about stitching together your own causal_survival_forest and trying that out first? Here is a template for a causal_survival_forest.custom that you can copy directly into an R session and experiment with (using ::: to access grf's private methods).

I added 5 TODO comments where you could replace grf's survival-based estimates with your own estimates. Please let me know if any of these are unclear.
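As a small illustration of the `:::` access mentioned above, here is a sketch that uses only base R rather than grf, so it runs without extra packages. The choice of stats:::plot.lm is purely illustrative. Which grf internals can be reached this way depends on the installed grf version, and private functions may change without notice.

```r
# `::` only reaches exported objects, while `:::` reaches anything in a namespace.
# plot.lm is used here only as an illustrative function from base R's stats package.
is.exported <- "plot.lm" %in% getNamespaceExports("stats")
internal.fn <- stats:::plot.lm

cat("listed among the exports:", is.exported, "\n")
cat("reachable via `:::`:", is.function(internal.fn), "\n")
```

The same pattern is what makes calls such as grf:::validate_X possible in a custom template, at the cost of depending on functions that are not part of the public API.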

bcjaeger commented 7 months ago

Awesome idea. I checked out the TODO steps and everything looked very clear. I will try this soon and let you know how it goes. Thank you!

bcjaeger commented 6 months ago

I've made a little progress. The reprex below runs with the development version of aorsf, but will not run with the current version on CRAN. I made a small update to aorsf to allow out-of-bag predictions on modified versions of the training data. I fiddled with some real data and encountered some efficiency issues with a large Y.grid, so I put in two additions (marked with comments that start with 'bcj addition') that were not part of the TODOs you left.

causal_survival_forest.custom <- function(
    X, Y, W, D,
    W.hat = NULL,
    target = c("RMST", "survival.probability"),
    horizon = NULL,
    failure.times = NULL,
    num.trees = 2000,
    sample.weights = NULL,
    clusters = NULL,
    equalize.cluster.weights = FALSE,
    sample.fraction = 0.5,
    mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
    min.node.size = 5,
    honesty = TRUE,
    honesty.fraction = 0.5,
    honesty.prune.leaves = TRUE,
    alpha = 0.05,
    imbalance.penalty = 0,
    stabilize.splits = TRUE,
    ci.group.size = 2,
    tune.parameters = "none",
    compute.oob.predictions = TRUE,
    num.threads = NULL,
    seed = runif(1, 0, .Machine$integer.max)) {

  target <- match.arg(target)
  if (is.null(horizon) || !is.numeric(horizon) || length(horizon) != 1) {
    stop("The `horizon` argument defining the estimand is required.")
  }

  has.missing.values <- grf:::validate_X(X, allow.na = TRUE)
  grf:::validate_sample_weights(sample.weights, X)
  Y <- grf:::validate_observations(Y, X)
  W <- grf:::validate_observations(W, X)
  D <- grf:::validate_observations(D, X)
  clusters <- grf:::validate_clusters(clusters, X)
  samples.per.cluster <- grf:::validate_equalize_cluster_weights(equalize.cluster.weights, clusters, sample.weights)
  num.threads <- grf:::validate_num_threads(num.threads)
  if (any(Y < 0)) {
    stop("The event times must be non-negative.")
  }
  if (!all(D %in% c(0, 1))) {
    stop("The censor values can only be 0 or 1.")
  }
  if (sum(D) == 0) {
    stop("All observations are censored.")
  }
  if (target == "RMST") {
    # f(T) <- min(T, horizon)
    D[Y >= horizon] <- 1
    Y[Y >= horizon] <- horizon
    fY <- Y
  } else {
    # f(T) <- 1{T > horizon}
    fY <- as.numeric(Y > horizon)
  }
  if (is.null(failure.times)) {

    Y.grid <- sort(unique(Y))

    # bcj addition 1
    # large Y.grid can slow computation down.
    # consider simplifying if Y.grid is > 100 points
    if(length(Y.grid) > 100){
      Y.grid <- seq(min(Y.grid), max(Y.grid), length.out = 100)
    }

  } else if (min(Y) < min(failure.times)) {
    stop("If provided, `failure.times` should be a grid starting on or before min(Y).")
  } else {

    # bcj addition 2
    # consider simplifying for computational efficiency
    if(length(failure.times) > 100){
      failure.times.orig <- failure.times
      failure.times <- seq(min(failure.times),
                           max(failure.times),
                           length.out = 100)

      # make the subsetted failure times be a proper subset of the original
      # failure times. It would be more efficient to sample the original,
      # but this helps to ensure the grid is more evenly spaced.
      for(i in seq_along(failure.times)){
        closest_index <- which.min(abs(failure.times[i] - failure.times.orig))
        failure.times[i] <- failure.times.orig[closest_index]
      }

      # in case there were duplicates introduced.
      failure.times <- unique(failure.times)

    }

    Y.grid <- failure.times
  }
  if (length(Y.grid) <= 2) {
    stop("The number of distinct event times should be more than 2.")
  }
  if (horizon < min(Y.grid)) {
    stop("`horizon` cannot be before the first event.")
  }
  if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
    warning(paste0("The number of events are more than 10% of the sample size. ",
                   "To reduce the computational burden of fitting survival and ",
                   "censoring curves, consider discretizing the event values `Y` or ",
                   "supplying a coarser grid with the `failure.times` argument. "), immediate. = TRUE)
  }

  if (is.null(W.hat)) {
    forest.W <- grf::regression_forest(X, W, num.trees = max(50, num.trees / 4),
                                       sample.weights = sample.weights, clusters = clusters,
                                       equalize.cluster.weights = equalize.cluster.weights,
                                       sample.fraction = sample.fraction, mtry = mtry,
                                       min.node.size = 5, honesty = TRUE,
                                       honesty.fraction = 0.5, honesty.prune.leaves = TRUE,
                                       alpha = alpha, imbalance.penalty = imbalance.penalty,
                                       ci.group.size = 1, tune.parameters = tune.parameters,
                                       compute.oob.predictions = TRUE,
                                       num.threads = num.threads, seed = seed)
    W.hat <- predict(forest.W)$predictions
  } else if (length(W.hat) == 1) {
    W.hat <- rep(W.hat, nrow(X))
  } else if (length(W.hat) != nrow(X)) {
    stop("W.hat has incorrect length.")
  }
  W.centered <- W - W.hat

  args.nuisance <- list(failure.times = failure.times,
                        num.trees = max(50, min(num.trees / 4, 500)),
                        sample.weights = sample.weights,
                        clusters = clusters,
                        equalize.cluster.weights = equalize.cluster.weights,
                        sample.fraction = sample.fraction,
                        mtry = mtry,
                        min.node.size = 15,
                        honesty = TRUE,
                        honesty.fraction = 0.5,
                        honesty.prune.leaves = TRUE,
                        alpha = alpha,
                        prediction.type = "Nelson-Aalen", # to guarantee non-zero estimates.
                        compute.oob.predictions = TRUE,
                        num.threads = num.threads,
                        seed = seed)

  # Compute survival-based nuisance components (https://arxiv.org/abs/2001.09887)
  # m(x) relies on the survival function conditional on only X, while Q(x) relies on the conditioning (X, W).
  # Instead of fitting two separate survival forests, we can use the forest fit on (X, W) to compute m(X)
  # using the identity
  # E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
  # (for this to work W has to be binary).

  # orsf would throw an error if columns were unnamed
  if(is.null(colnames(X))) colnames(X) <- paste("x", seq(ncol(X)), sep = "_")

  orsf_data <- as.data.frame(cbind(y = Y, d = D, w = W, X))

  # this is to prevent aorsf from throwing an error when it
  # encounters event times of 0. I should remove this assertion
  # from aorsf, but for now:
  orsf_data$y <- pmax(orsf_data$y, .Machine$double.eps)

  # default is to use unique event times
  if(is.null(args.nuisance$failure.times)){
    args.nuisance$failure.times <- sort(unique(Y[D==1]))
  }

  # plugging in inputs from args.nuisance where possible
  sf.survival <- aorsf::orsf(
    data = orsf_data,
    formula = y + d ~ .,
    n_tree = args.nuisance$num.trees,
    weights = args.nuisance$sample.weights,
    sample_fraction = args.nuisance$sample.fraction,
    mtry = args.nuisance$mtry,
    # using min node size for both leaf stopping criteria.
    # This will usually lead to more shallow trees.
    leaf_min_obs = args.nuisance$min.node.size,
    leaf_min_events = args.nuisance$min.node.size,
    oobag_pred_type = "surv",
    oobag_pred_horizon = args.nuisance$failure.times,
    tree_seeds = round(args.nuisance$seed)
  )

  binary.W <- all(W %in% c(0, 1))

  if (binary.W) {

    # The survival function conditioning on being treated S(t, x, 1) estimated with an "S-learner".
    # grf's original code needs a manual workaround to get OOB estimates for modified training
    # samples; here the development version of aorsf does this via predict(..., oobag = TRUE).

    orsf_data$w <- 1
    S1.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- 0
    S0.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- W

    if (target == "RMST") {
      Y.hat <- W.hat * grf:::expected_survival(S1.hat, sf.survival$pred_horizon) +
        (1 - W.hat) * grf:::expected_survival(S0.hat, sf.survival$pred_horizon)
    } else {
      horizonS.index <- findInterval(horizon, sf.survival$pred_horizon)
      if (horizonS.index == 0) {
        Y.hat <- rep(1, nrow(X))
      } else {
        Y.hat <- W.hat * S1.hat[, horizonS.index] + (1 - W.hat) * S0.hat[, horizonS.index]
      }
    }

  } else {
    # Ignoring this code branch for simplicity's sake
    stop("Custom survival models + continuous treatment not implemented")

    # If continuous W fit a separate survival forest to estimate E[f(T) | X].
    # sf.Y <- do.call(grf::survival_forest, c(list(X = X, Y = Y, D = D), args.nuisance))
    # SY.hat <- predict(sf.Y)$predictions
    # if (target == "RMST") {
    #   Y.hat <- expected_survival(SY.hat, sf.Y$failure.times)
    # } else {
    #   horizonS.index <- findInterval(horizon, sf.survival$failure.times)
    #   if (horizonS.index == 0) {
    #     Y.hat <- rep(1, nrow(X))
    #   } else {
    #     Y.hat <- SY.hat[, horizonS.index]
    #   }
    # }
  }

  # The conditional survival function S(t, x, w) used to construct Q(x).
  S.hat <- predict(sf.survival, oobag = TRUE, pred_horizon = Y.grid)

  if (!identical(dim(S.hat), c(length(Y), length(Y.grid)))) stop("Wrong S.hat prediction dims")

  # The conditional survival function for the censoring process S_C(t, x, w).
  orsf_data$d <- 1 - D

  sf.censor <- aorsf::orsf_update(sf.survival,
                                  data = orsf_data,
                                  # default split_min_stat is about 3,
                                  # setting to 10 makes trees more shallow
                                  split_min_stat = 10,
                                  oobag_pred_horizon = Y.grid)

  C.hat <- sf.censor$pred_oobag

  if (!identical(dim(C.hat), c(length(Y), length(Y.grid)))) stop("Wrong C.hat prediction dims")

  if (target == "survival.probability") {
    # Evaluate psi up to horizon
    D[Y > horizon] <- 1
    Y[Y > horizon] <- horizon
  }

  Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
  C.Y.hat <- C.hat[cbind(seq_along(Y.index), Y.index)] # Pick out P[Ci > Yi | Xi, Wi]

  if (target == "RMST" && any(C.Y.hat <= 0.05)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M).",
                  "This warning appears when M is less than 0.05, at which point causal survival forest",
                  "can not be expected to deliver reliable estimates."), immediate. = TRUE)
  } else if (target == "RMST" && any(C.Y.hat < 0.2)) {
    warning(paste("Estimated censoring probabilities are lower than 0.2",
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M)."))
  } else if (target == "survival.probability" && any(C.Y.hat <= 0.001)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- forest estimates will likely be very unstable, a smaller target `horizon`",
                  "is recommended."), immediate. = TRUE)
  } else if (target == "survival.probability" && any(C.Y.hat < 0.05)) {
    warning(paste("Estimated censoring probabilities are lower than 0.05",
                  "and forest estimates may not be stable. Using a smaller target `horizon`",
                  "may help."))
  }

  psi <- grf:::compute_psi(S.hat, C.hat, C.Y.hat, Y.hat, W.centered,
                           D, fY, Y.index, Y.grid, target, horizon)
  grf:::validate_observations(psi[["numerator"]], X)
  grf:::validate_observations(psi[["denominator"]], X)

  data <- grf:::create_train_matrices(X,
                                      treatment = W.centered,
                                      survival.numerator = psi[["numerator"]],
                                      survival.denominator = psi[["denominator"]],
                                      censor = D,
                                      sample.weights = sample.weights)

  args <- list(num.trees = num.trees,
               clusters = clusters,
               samples.per.cluster = samples.per.cluster,
               sample.fraction = sample.fraction,
               mtry = mtry,
               min.node.size = min.node.size,
               honesty = honesty,
               honesty.fraction = honesty.fraction,
               honesty.prune.leaves = honesty.prune.leaves,
               alpha = alpha,
               imbalance.penalty = imbalance.penalty,
               stabilize.splits = stabilize.splits,
               ci.group.size = ci.group.size,
               compute.oob.predictions = compute.oob.predictions,
               num.threads = num.threads,
               seed = seed)

  forest <- grf:::do.call.rcpp(grf:::causal_survival_train, c(data, args))
  class(forest) <- c("causal_survival_forest", "grf")
  forest[["seed"]] <- seed
  forest[["_psi"]] <- psi
  forest[["X.orig"]] <- X
  forest[["Y.orig"]] <- Y
  forest[["W.orig"]] <- W
  forest[["D.orig"]] <- D
  forest[["Y.hat"]] <- Y.hat
  forest[["W.hat"]] <- W.hat
  forest[["sample.weights"]] <- sample.weights
  forest[["clusters"]] <- clusters
  forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
  forest[["has.missing.values"]] <- has.missing.values
  forest[["target"]] <- target
  forest[["horizon"]] <- horizon

  forest
}

n <- 500
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
horizon <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, horizon)
censor.time <- 2 * runif(n)
Y <- round(pmin(failure.time, censor.time), 2)
D <- as.integer(failure.time <= censor.time)

# grf causal survival forest
csf.orig <- grf::causal_survival_forest(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.orig)
#>   estimate    std.err 
#> 0.59288638 0.02277714
head(predict(csf.orig))
#>   predictions
#> 1   0.5627453
#> 2   0.4707053
#> 3   0.6011034
#> 4   0.6892941
#> 5   0.5707982
#> 6   0.5513207

# your custom CS forest
csf.custom <- causal_survival_forest.custom(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.custom)
#>   estimate    std.err 
#> 0.59504669 0.02245838
head(predict(csf.custom))
#>   predictions
#> 1   0.5710990
#> 2   0.4703588
#> 3   0.6050363
#> 4   0.6909468
#> 5   0.5722461
#> 6   0.5575252

Created on 2024-04-30 with reprex v2.1.0
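For readers following the Y.hat step in the RMST branch, here is a minimal sketch of the arithmetic under stated assumptions. The helper rmst_step is hypothetical and written for this illustration; it is not grf's expected_survival, although it is meant to convey the same idea. It assumes the survival curve is a right-continuous step function that equals 1 before the first grid time, and it combines the two treatment arms with the identity quoted in the code comments, E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0].

```r
# Hypothetical helper: area under a step survival curve from 0 to max(grid).
# S[j] is the survival probability at grid[j]; the curve is taken to be 1 before grid[1].
rmst_step <- function(S, grid) {
  widths <- diff(c(0, grid))           # length of each interval
  heights <- c(1, S[-length(S)])       # curve height over each interval
  sum(heights * widths)
}

grid <- c(1, 2, 3)
S1 <- c(0.9, 0.7, 0.4)                 # toy curve for W = 1
S0 <- c(0.8, 0.5, 0.2)                 # toy curve for W = 0
e.hat <- 0.5                           # toy propensity score

rmst1 <- rmst_step(S1, grid)           # 1 + 0.9 + 0.7
rmst0 <- rmst_step(S0, grid)           # 1 + 0.8 + 0.5
Y.hat.toy <- e.hat * rmst1 + (1 - e.hat) * rmst0
print(c(rmst1 = rmst1, rmst0 = rmst0, Y.hat = Y.hat.toy))
```

In the reprex the same combination is applied row by row, with W.hat in place of the toy propensity and the out-of-bag curves S1.hat and S0.hat in place of the toy curves.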

erikcs commented 6 months ago

Very cool! Yes, the Y.grid shouldn't be too large. We emit a warning with suggestions if that is the case, but further down the line it could be nice to add some automated grid selection. (PS: in case it's useful to keep in mind when doing your modifications, the CSF code expects both S.hat and C.hat to be indexed by the same time grid.)
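To make the grid discussion concrete, here is a rough sketch of one way to coarsen a time grid while keeping every grid point an observed time, and of how observations are then indexed into that single grid. The helper coarsen_grid and the toy data are assumptions made for illustration; neither is part of grf or aorsf.

```r
# Hypothetical helper: keep at most `max.points` grid times, each an observed time.
coarsen_grid <- function(times, max.points = 100) {
  times <- sort(unique(times))
  if (length(times) <= max.points) return(times)
  # Evenly spaced targets, each snapped to the nearest observed time.
  target <- seq(min(times), max(times), length.out = max.points)
  idx <- vapply(target, function(t) which.min(abs(t - times)), integer(1))
  times[unique(idx)]
}

set.seed(1)
Y <- round(rexp(1000), 2) + 0.01       # toy follow-up times, strictly positive
Y.grid <- coarsen_grid(Y)

# Each observation maps to the last grid time at or before its own time.
# Survival and censoring curves would both need one column per entry of Y.grid
# so that this index picks out matching columns in each matrix.
Y.index <- findInterval(Y, Y.grid)
cat("grid size:", length(Y.grid), "\n")
cat("index range:", range(Y.index), "\n")
```

Because the grid starts at min(Y), every index is at least 1, which is the invariant the reprex relies on when it subsets C.hat by Y.index.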

erikcs commented 2 months ago

Hi @bcjaeger, I'm just curious if you ended up discovering any interesting differences when forming nuisance estimates with oblique random forests over plain random forests? Further down the line it could be cool if we somehow could allow advanced users to leverage your excellent aorsf package in cases where grf's survival_forest is expected to fall short : )

bcjaeger commented 2 months ago

Hi @erikcs, thanks! That is a great question. I'd like to run some tests this week and let you know if I see any differences. Are there any simulated or real data sets you'd like to see in the comparison?

erikcs commented 2 months ago

Hi, I did not have anything in particular in mind; I only asked in case you'd already tried this out! I could also imagine your oblique RSF being faster than grf's RSF.

bcjaeger commented 2 months ago

👋 Sorry for the late reply. I decided to try an experiment where I modified pbc_orsf by creating a random treatment indicator and then doubling the event time for patients with an observed event who had hepato == 1 and trt == 1. Both the original and custom methods detected that hepato explained heterogeneity in the effect of trt, with only a very small difference in the best linear projections.

library(grf)
library(aorsf)
library(magrittr)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

causal_survival_forest.custom <- function(
  X, Y, W, D,
  W.hat = NULL,
  target = c("RMST", "survival.probability"),
  horizon = NULL,
  failure.times = NULL,
  num.trees = 2000,
  sample.weights = NULL,
  clusters = NULL,
  equalize.cluster.weights = FALSE,
  sample.fraction = 0.5,
  mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
  min.node.size = 5,
  honesty = TRUE,
  honesty.fraction = 0.5,
  honesty.prune.leaves = TRUE,
  alpha = 0.05,
  imbalance.penalty = 0,
  stabilize.splits = TRUE,
  ci.group.size = 2,
  tune.parameters = "none",
  compute.oob.predictions = TRUE,
  num.threads = NULL,
  seed = runif(1, 0, .Machine$integer.max)) {

 target <- match.arg(target)
 if (is.null(horizon) || !is.numeric(horizon) || length(horizon) != 1) {
  stop("The `horizon` argument defining the estimand is required.")
 }

 has.missing.values <- grf:::validate_X(X, allow.na = TRUE)
 grf:::validate_sample_weights(sample.weights, X)
 Y <- grf:::validate_observations(Y, X)
 W <- grf:::validate_observations(W, X)
 D <- grf:::validate_observations(D, X)
 clusters <- grf:::validate_clusters(clusters, X)
 samples.per.cluster <- grf:::validate_equalize_cluster_weights(equalize.cluster.weights, clusters, sample.weights)
 num.threads <- grf:::validate_num_threads(num.threads)
 if (any(Y < 0)) {
  stop("The event times must be non-negative.")
 }
 if (!all(D %in% c(0, 1))) {
  stop("The censor values can only be 0 or 1.")
 }
 if (sum(D) == 0) {
  stop("All observations are censored.")
 }
 if (target == "RMST") {
  # f(T) <- min(T, horizon)
  D[Y >= horizon] <- 1
  Y[Y >= horizon] <- horizon
  fY <- Y
 } else {
  # f(T) <- 1{T > horizon}
  fY <- as.numeric(Y > horizon)
 }
 if (is.null(failure.times)) {

  Y.grid <- sort(unique(Y))

  # bcj addition 1
  # large Y.grid can slow computation down.
  # consider simplifying if Y.grid is > 100 points
  if(length(Y.grid) > 100){
   Y.grid <- seq(min(Y.grid), max(Y.grid), length.out = 100)
  }

 } else if (min(Y) < min(failure.times)) {
  stop("If provided, `failure.times` should be a grid starting on or before min(Y).")
 } else {

  # bcj addition 2
  # consider simplifying for computational efficiency
  if(length(failure.times) > 100){
   failure.times.orig <- failure.times
   failure.times <- seq(min(failure.times),
                        max(failure.times),
                        length.out = 100)

   # make the subsetted failure times be a proper subset of the original
   # failure times. It would be more efficient to sample the original,
   # but this helps to ensure the grid is more evenly spaced.
   for(i in seq_along(failure.times)){
    closest_index <- which.min(abs(failure.times[i] - failure.times.orig))
    failure.times[i] <- failure.times.orig[closest_index]
   }

   # in case there were duplicates introduced.
   failure.times <- unique(failure.times)

  }

  Y.grid <- failure.times
 }
 if (length(Y.grid) <= 2) {
  stop("The number of distinct event times should be more than 2.")
 }
 if (horizon < min(Y.grid)) {
  stop("`horizon` cannot be before the first event.")
 }
 if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
  warning(paste0("The number of events are more than 10% of the sample size. ",
                 "To reduce the computational burden of fitting survival and ",
                 "censoring curves, consider discretizing the event values `Y` or ",
                 "supplying a coarser grid with the `failure.times` argument. "), immediate. = TRUE)
 }

 if (is.null(W.hat)) {
  forest.W <- grf::regression_forest(X, W, num.trees = max(50, num.trees / 4),
                                     sample.weights = sample.weights, clusters = clusters,
                                     equalize.cluster.weights = equalize.cluster.weights,
                                     sample.fraction = sample.fraction, mtry = mtry,
                                     min.node.size = 5, honesty = TRUE,
                                     honesty.fraction = 0.5, honesty.prune.leaves = TRUE,
                                     alpha = alpha, imbalance.penalty = imbalance.penalty,
                                     ci.group.size = 1, tune.parameters = tune.parameters,
                                     compute.oob.predictions = TRUE,
                                     num.threads = num.threads, seed = seed)
  W.hat <- predict(forest.W)$predictions
 } else if (length(W.hat) == 1) {
  W.hat <- rep(W.hat, nrow(X))
 } else if (length(W.hat) != nrow(X)) {
  stop("W.hat has incorrect length.")
 }
 W.centered <- W - W.hat

 args.nuisance <- list(failure.times = failure.times,
                       num.trees = max(50, min(num.trees / 4, 500)),
                       sample.weights = sample.weights,
                       clusters = clusters,
                       equalize.cluster.weights = equalize.cluster.weights,
                       sample.fraction = sample.fraction,
                       mtry = mtry,
                       min.node.size = 15,
                       honesty = TRUE,
                       honesty.fraction = 0.5,
                       honesty.prune.leaves = TRUE,
                       alpha = alpha,
                       prediction.type = "Nelson-Aalen", # to guarantee non-zero estimates.
                       compute.oob.predictions = TRUE,
                       num.threads = num.threads,
                       seed = seed)

 # Compute survival-based nuisance components (https://arxiv.org/abs/2001.09887)
 # m(x) relies on the survival function conditional on only X, while Q(x) relies on the conditioning (X, W).
 # Instead of fitting two separate survival forests, we can use the forest fit on (X, W) to compute m(X)
 # using the identity
 # E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
 # (for this to work W has to be binary).

 # orsf would throw an error if columns were unnamed
 if(is.null(colnames(X))) colnames(X) <- paste("x", seq(ncol(X)), sep = "_")

 orsf_data <- as.data.frame(cbind(y = Y, d = D, w = W, X))

 # this is to prevent aorsf from throwing an error when it
 # encounters event times of 0. I should remove this assertion
 # from aorsf, but for now:
 orsf_data$y <- pmax(orsf_data$y, .Machine$double.eps)

 # default is to use unique event times
 if(is.null(args.nuisance$failure.times)){
  args.nuisance$failure.times <- sort(unique(Y[D==1]))
 }

 # plugging in inputs from args.nuisance where possible
 sf.survival <- aorsf::orsf(
  data = orsf_data,
  formula = y + d ~ .,
  n_tree = args.nuisance$num.trees,
  weights = args.nuisance$sample.weights,
  sample_fraction = args.nuisance$sample.fraction,
  mtry = args.nuisance$mtry,
  leaf_min_obs = args.nuisance$min.node.size,
  oobag_pred_type = "surv",
  oobag_pred_horizon = args.nuisance$failure.times,
  tree_seeds = round(args.nuisance$seed)
 )

 binary.W <- all(W %in% c(0, 1))

 if (binary.W) {

  # The survival function conditioning on being treated S(t, x, 1) estimated with an "S-learner".
  # grf's original code needs a manual workaround to get OOB estimates for modified training
  # samples; here the development version of aorsf does this via predict(..., oobag = TRUE).

  orsf_data$w <- 1
  S1.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)

  orsf_data$w <- 0
  S0.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)

  orsf_data$w <- W

  if (target == "RMST") {

   Y.hat <- W.hat * grf:::expected_survival(S1.hat, sf.survival$pred_horizon) +
    (1 - W.hat) * grf:::expected_survival(S0.hat, sf.survival$pred_horizon)

  } else {

   horizonS.index <- findInterval(horizon, sf.survival$pred_horizon)

   if (horizonS.index == 0) {
    Y.hat <- rep(1, nrow(X))
   } else {
    Y.hat <- W.hat * S1.hat[, horizonS.index] + (1 - W.hat) * S0.hat[, horizonS.index]
   }

  }

 } else {
  # Ignoring this code branch for simplicity's sake
  stop("Custom survival models + continuous treatment not implemented")

  # If continuous W fit a separate survival forest to estimate E[f(T) | X].
  # sf.Y <- do.call(grf::survival_forest, c(list(X = X, Y = Y, D = D), args.nuisance))
  # SY.hat <- predict(sf.Y)$predictions
  # if (target == "RMST") {
  #   Y.hat <- expected_survival(SY.hat, sf.Y$failure.times)
  # } else {
  #   horizonS.index <- findInterval(horizon, sf.survival$failure.times)
  #   if (horizonS.index == 0) {
  #     Y.hat <- rep(1, nrow(X))
  #   } else {
  #     Y.hat <- SY.hat[, horizonS.index]
  #   }
  # }
 }

 # The conditional survival function S(t, x, w) used to construct Q(x).
 S.hat <- predict(sf.survival, oobag = TRUE, pred_horizon = Y.grid)

 if (!identical(dim(S.hat), c(length(Y), length(Y.grid)))) stop("Wrong S.hat prediction dims")

 # The conditional survival function for the censoring process S_C(t, x, w).
 orsf_data$d <- 1 - D

 sf.censor <- aorsf::orsf_update(sf.survival,
                                 data = orsf_data,
                                 # default split_min_stat is about 3,
                                 # setting to 10 makes trees more shallow
                                 split_min_stat = 10,
                                 oobag_pred_horizon = Y.grid)

 C.hat <- sf.censor$pred_oobag

 if (!identical(dim(C.hat), c(length(Y), length(Y.grid)))) stop("Wrong C.hat prediction dims")

 if (target == "survival.probability") {
  # Evaluate psi up to horizon
  D[Y > horizon] <- 1
  Y[Y > horizon] <- horizon
 }

 Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
 C.Y.hat <- C.hat[cbind(seq_along(Y.index), Y.index)] # Pick out P[Ci > Yi | Xi, Wi]

 if (target == "RMST" && any(C.Y.hat <= 0.05)) {
  warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                "- an identifying assumption is that there exists a fixed positive constant M",
                "such that the probability of observing an event past the maximum follow-up time ",
                "is at least M (i.e. P(T > horizon | X) > M).",
                "This warning appears when M is less than 0.05, at which point causal survival forest",
                "can not be expected to deliver reliable estimates."), immediate. = TRUE)
 } else if (target == "RMST" && any(C.Y.hat < 0.2)) {
  warning(paste("Estimated censoring probabilities are lower than 0.2",
                "- an identifying assumption is that there exists a fixed positive constant M",
                "such that the probability of observing an event past the maximum follow-up time ",
                "is at least M (i.e. P(T > horizon | X) > M)."))
 } else if (target == "survival.probability" && any(C.Y.hat <= 0.001)) {
  warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                "- forest estimates will likely be very unstable, a smaller target `horizon`",
                "is recommended."), immediate. = TRUE)
 } else if (target == "survival.probability" && any(C.Y.hat < 0.05)) {
  warning(paste("Estimated censoring probabilities are lower than 0.05",
                "and forest estimates may not be stable. Using a smaller target `horizon`",
                "may help."))
 }

 psi <- grf:::compute_psi(S.hat, C.hat, C.Y.hat, Y.hat, W.centered,
                          D, fY, Y.index, Y.grid, target, horizon)
 grf:::validate_observations(psi[["numerator"]], X)
 grf:::validate_observations(psi[["denominator"]], X)

 data <- grf:::create_train_matrices(X,
                                     treatment = W.centered,
                                     survival.numerator = psi[["numerator"]],
                                     survival.denominator = psi[["denominator"]],
                                     censor = D,
                                     sample.weights = sample.weights)

 args <- list(num.trees = num.trees,
              clusters = clusters,
              samples.per.cluster = samples.per.cluster,
              sample.fraction = sample.fraction,
              mtry = mtry,
              min.node.size = min.node.size,
              honesty = honesty,
              honesty.fraction = honesty.fraction,
              honesty.prune.leaves = honesty.prune.leaves,
              alpha = alpha,
              imbalance.penalty = imbalance.penalty,
              stabilize.splits = stabilize.splits,
              ci.group.size = ci.group.size,
              compute.oob.predictions = compute.oob.predictions,
              num.threads = num.threads,
              seed = seed)

 forest <- grf:::do.call.rcpp(grf:::causal_survival_train, c(data, args))
 class(forest) <- c("causal_survival_forest", "grf")
 forest[["seed"]] <- seed
 forest[["_psi"]] <- psi
 forest[["X.orig"]] <- X
 forest[["Y.orig"]] <- Y
 forest[["W.orig"]] <- W
 forest[["D.orig"]] <- D
 forest[["Y.hat"]] <- Y.hat
 forest[["W.hat"]] <- W.hat
 forest[["sample.weights"]] <- sample.weights
 forest[["clusters"]] <- clusters
 forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
 forest[["has.missing.values"]] <- has.missing.values
 forest[["target"]] <- target
 forest[["horizon"]] <- horizon

 forest
}

set.seed(329)

results <- replicate(
 n = 50, 
 simplify = FALSE,
 expr = {

  data_sim <- pbc_orsf %>% 
   mutate(stage = factor(stage, ordered = FALSE)) %>% 
   select(-id) %>% 
   mutate(trt = rbinom(n(), size = 1, prob = 1/2))

  # treatment doubles event time for patients w/hepato
  index <- with(data_sim, trt == 1 & status == 1 & hepato == 1)

  data_sim$time[index] %<>% multiply_by(2)

  X <- model.matrix(~. -1L, data = select(data_sim ,-c(time, status, trt)))
  Y <- data_sim$time
  W <- data_sim$trt
  D <- data_sim$status

  w <- mean(data_sim$trt)

  horizon <- 2000

  csf.orig <- 
   causal_survival_forest(X, Y, W, D, W.hat = w, 
                          target = 'survival.probability',
                          horizon = horizon)

  results_orig <- 
   best_linear_projection(csf.orig, A = X)['hepato1', 'Estimate']

  csf.custom <- 
   causal_survival_forest.custom(X, Y, W, D, W.hat = w, 
                                 target = 'survival.probability',
                                 horizon = horizon)

  results_custom <- 
   best_linear_projection(csf.custom, A = X)['hepato1', 'Estimate']

  tibble(orig = results_orig, custom = results_custom)

 }
)

bind_rows(results) %>% 
 summarize(across(everything(), mean))
#> # A tibble: 1 × 2
#>    orig custom
#>   <dbl>  <dbl>
#> 1 0.110  0.111

Created on 2024-09-21 with reprex v2.0.2

I also ran a benchmark on computation time, and I was surprised by how fast survival_forest() is!

library(survival)
library(grf)
library(aorsf)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

flc <- flchain %>% 
 na.omit() %>% 
 filter(futime > 0) %>% 
 select(-chapter) %>% 
 mutate(sex = as.numeric(sex))

X <- as.matrix(select(flc, -futime, -death))
Y <- flc$futime
D <- flc$death

microbenchmark::microbenchmark(

 grf = survival_forest(X = X, Y = Y, D = D, 
                       mtry = 2,
                       num.trees = 500, 
                       min.node.size = 10),

 aorsf = orsf(flc, 
              futime + death ~ ., 
              leaf_min_obs = 10, 
              mtry = 2, 
              sample_fraction = 1/2, 
              sample_with_replacement = FALSE),

 times = 5
)
#> Unit: milliseconds
#>   expr      min       lq     mean   median       uq      max neval cld
#>    grf 151.1600 157.4182 209.1205 172.7405 271.6018 292.6821     5  a 
#>  aorsf 307.7779 349.7805 372.7963 356.5283 414.7946 435.1001     5   b

Created on 2024-09-21 with reprex v2.0.2

All things considered, I haven't found a case where aorsf improved meaningfully on the existing grf tools yet, but I'll keep thinking on it.

erikcs commented 1 month ago

Thank you, that's nice to know!