kogalur / randomForestSRC

DOCUMENTATION:
https://www.randomforestsrc.org/
GNU General Public License v3.0
115 stars 18 forks source link

ntime and runtime #404

Closed sibipx closed 9 months ago

sibipx commented 10 months ago

To speed up calculations, and because I am not interested in predictions at all event times in the dataset, I specify ntime (ntime = 1:7) for both the survival and the competing risks models.

I would expect the log-rank split statistic to be calculated by summing over only these 7 ntime points (or the nearest observed event times). I would therefore expect the same runtime reduction as when I provide discrete event times in the dataset (1, 2, 3, 4, 5, 6, 7).

However, there is a big runtime difference between the two situations, and the difference is larger for competing risks.

As I am working with large datasets and running a lot of models, relying on ntime alone (without rounding the event times with ceiling()) is not feasible for me at all.

Of course, I might have the wrong expectations about what ntime does; perhaps it should not be expected to run as fast as the discrete-time case. I would appreciate an explanation.

Below is an example; the runtimes on my Windows machine are given in the comments.

Thank you!

library(randomForestSRC)

# Simulated benchmark data ----
# n subjects, p standard-normal covariates (columns X1..Xp in the data
# frames), follow-up times drawn uniformly on (0, 100).
n <- 10000
p <- 20

set.seed(2023)
x <- replicate(p, rnorm(n))
time <- runif(n, 0, 100)

# Competing-risks status: rounding a uniform draw on (1, 2) yields 1 or 2,
# so every subject has an event (no censoring).
status_CR <- round(runif(n, 1, 2))

# Single-event status: cause 1 is the event (1); cause-2 events are treated
# as censored (0) at their event time.
status_surv <- as.numeric(status_CR == 1)

# Each outcome is paired with the raw times and with times rounded up to
# whole units via ceiling().
data_CR         <- data.frame(time = time,          status = status_CR,   x)
data_CR_round   <- data.frame(time = ceiling(time), status = status_CR,   x)
data_surv       <- data.frame(time = time,          status = status_surv, x)
data_surv_round <- data.frame(time = ceiling(time), status = status_surv, x)

# Runtime benchmark ----
# All four forests use identical settings. Only the data (raw vs.
# ceiling-rounded event times) and the family (competing risks vs. survival)
# change. The trailing "# DONE in ..." comments are the timings observed on
# the reporter's Windows machine.

# Fit one forest with the shared benchmark settings, report the elapsed
# wall-clock time in minutes via message(), and return the fitted forest.
#   data  - data frame with columns `time`, `status` and the covariates
#   ntime - time grid passed to rfsrc() (1:7 in every run below)
#   ...   - extra rfsrc() arguments (the competing-risks runs pass
#           `cause` and `splitrule`)
fit_timed <- function(data, ntime = 1:7, ...) {
  start_time <- Sys.time()
  fit <- rfsrc(Surv(time, status) ~ ., data,
               samptype = "swr",
               nodesize = 100,
               mtry = 10,
               ntree = 500,
               save.memory = TRUE,
               do.trace = FALSE,
               importance = "none",
               perf.type = "none",
               ntime = ntime,
               ...)
  elapsed_min <- as.numeric(difftime(Sys.time(), start_time, units = "mins"))
  message(sprintf("DONE in %s minutes.", elapsed_min))
  fit
}

# competing risks
RF_model_CR <- fit_timed(data_CR, cause = 1, splitrule = "logrank")
# DONE in 6.73131632010142 minutes.

# competing risks with rounded times
# (stored under its own name; previously this overwrote RF_model_CR)
RF_model_CR_round <- fit_timed(data_CR_round,
                               cause = 1,
                               splitrule = "logrank")
# DONE in 0.477894465128581 minutes.

# survival
RF_model_surv <- fit_timed(data_surv)
# DONE in 2.11460233132044 minutes.

# survival with rounded time points
RF_model_surv_round <- fit_timed(data_surv_round)
# DONE in 0.373596449693044 minutes.
sibipx commented 10 months ago

Small update: ntime does not seem to make a difference when passed as a vector of discrete values, not only in terms of computational performance but also in terms of functional behaviour. It looks like it is parsed correctly in the R code, but the trees seem to be grown using all event times: I get exactly the same splits and exactly the same predictions no matter how I set ntime.

I think I can get the behaviour I want (and a decent computational time) by manually rounding the event times (with ceiling()) and manually applying administrative censoring in the data, but I would appreciate a clarification of the expected behaviour of ntime. Thanks!

data(pbc, package = "randomForestSRC")
# Keep follow-up time, event indicator and one covariate; drop rows with
# any missing value among those columns.
pbc <- na.omit(pbc[, c("days", "status", "copper")])

# manipulate the event times to simplify matters:
# round every follow-up time up to the next multiple of 100 days
pbc$days <- 100 * ceiling(pbc$days / 100)
sort(unique(pbc$days))

# ntime experiments ----
# Each model is a single tree grown with the same seed. Only the requested
# `ntime` grid (or the censoring in the data) differs. The checks ask whether
# the out-of-bag survival estimates agree at the time points the fits share.

# Compare the OOB survival matrices of two fitted forests at the time points
# present in both `time.interest` grids. Columns are matched by time value,
# not by position. The grids can differ in length and in which event times
# they keep, so a positional `[, 1:7]` slice may compare different time
# points or fail with "subscript out of bounds".
# Returns TRUE only if the matched columns are exactly identical; stops if
# the fits have no time point in common (an empty comparison would be a
# vacuous TRUE).
same_oob_survival <- function(fit_a, fit_b) {
  shared_times <- intersect(fit_a$time.interest, fit_b$time.interest)
  if (length(shared_times) == 0) {
    stop("the two fits share no time points", call. = FALSE)
  }
  cols_a <- match(shared_times, fit_a$time.interest)
  cols_b <- match(shared_times, fit_b$time.interest)
  identical(fit_a$survival.oob[, cols_a, drop = FALSE],
            fit_b$survival.oob[, cols_b, drop = FALSE])
}

# survival ntime 100 to 700
set.seed(2023)
RF_model_surv_7 <- rfsrc(Surv(days, status) ~ ., pbc,
                         ntree = 1,
                         nodesize = 30,
                         ntime = (1:7) * 100)

plot(get.tree(RF_model_surv_7, 1))
RF_model_surv_7$time.interest

# survival ntime 100 to 3000
set.seed(2023)
RF_model_surv_30 <- rfsrc(Surv(days, status) ~ ., pbc,
                          ntree = 1,
                          nodesize = 30,
                          ntime = (1:30) * 100)

plot(get.tree(RF_model_surv_30, 1))
RF_model_surv_30$time.interest

same_oob_survival(RF_model_surv_7, RF_model_surv_30)

# survival with no timepoints
set.seed(2023)
RF_model_surv_notime <- rfsrc(Surv(days, status) ~ ., pbc,
                              ntree = 1,
                              nodesize = 30)

plot(get.tree(RF_model_surv_notime, 1))
RF_model_surv_notime$time.interest

same_oob_survival(RF_model_surv_7, RF_model_surv_notime)

# administrative censoring at 700 days: events after day 700 are recoded as
# censored (status 0); the follow-up times themselves are left unchanged
adm_cutoff <- 700
pbc_admcens <- pbc
pbc_admcens$status <- ifelse(pbc_admcens$days > adm_cutoff,
                             0,
                             pbc_admcens$status)
set.seed(2023)
RF_model_surv_7_admcens <- rfsrc(Surv(days, status) ~ ., pbc_admcens,
                                 ntree = 1,
                                 nodesize = 30,
                                 ntime = (1:7) * 100)

plot(get.tree(RF_model_surv_7_admcens, 1))
RF_model_surv_7_admcens$time.interest

same_oob_survival(RF_model_surv_7, RF_model_surv_7_admcens)