mayer79 / splitTools

Lightweight R package to do fast data splitting for cross-validation or train/valid/test splits
https://mayer79.github.io/splitTools/
GNU General Public License v2.0

Support splitting of survival data #13

Closed kapsner closed 2 years ago

kapsner commented 2 years ago

First of all, thank you very much for this awesome package - I am using it a lot in my work!

To enhance it a little, I would like to suggest adding the possibility to split survival data (especially in a stratified manner).

Currently, one needs to decide whether survival data is split by the survival-status variable, by a grouping variable (e.g. disease groups), or even by survival times. However, it would be very helpful to allow (stratified) splitting by all of the above. This would help both when partitioning datasets into train-validation-test sets and in cross-validation, ensuring a similar distribution of survival status, survival time and (optionally) disease groups.

(As I did not find anything on this topic in a Google search, I am not sure whether such an approach would make sense at all, or whether simple approaches I am not currently aware of already exist. I came across this topic while analyzing survival data with a train-test split, using cross-validation on the training split for hyperparameter optimization. The results are very non-robust: they depend heavily on the seed I set before splitting the data, and some heterogeneity of the survival curves is visible in the Kaplan-Meier plots.)

mayer79 commented 2 years ago

Hello and thanks for the good question!

As far as I can tell, there is no ICH/GCP guidance on this sort of question. It could even be worth writing a short article pointing out the problem of instability, which seems to be a source of bias that might become more and more relevant in medical research, illustrated with real data and perhaps supplemented by a simulation study.

My current recommendation in your situation is as follows:

Study is an RCT

Here, I would stratify on the study group and ignore survival. To get rid of the high uncertainty from CV, do repeated CV. To get rid of the high uncertainty from the initial train/test split, nested CV is an option.
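A minimal base-R sketch of the repeated-CV part of this recommendation, assuming the study arm is the only stratification variable. `stratified_folds()` is a hypothetical helper, not a splitTools function: it shuffles the rows of each arm and deals them out over k folds, and calling it with several seeds gives one fold assignment per CV repetition.

```r
# Hypothetical helper (not part of splitTools): assign rows to k folds,
# stratified by `strata` (e.g. the study arm of an RCT)
stratified_folds <- function(strata, k = 5L) {
  fold <- integer(length(strata))
  # split() groups the row indices by stratum; drop = TRUE skips empty strata
  for (idx in split(seq_along(strata), strata, drop = TRUE)) {
    idx <- idx[sample.int(length(idx))]               # shuffle rows within the stratum
    fold[idx] <- rep_len(sample.int(k), length(idx))  # deal them out over the k folds
  }
  fold
}

# Repeated CV: one fold vector per repetition, each with its own seed
arm <- rep(c("placebo", "treatment"), times = c(40, 60))
reps <- lapply(1:3, function(s) {
  set.seed(s)
  stratified_folds(arm, k = 5L)
})
table(arm, reps[[1]])
```

Cycling through a random permutation of the fold labels (rather than always starting at fold 1) spreads the leftover rows of each stratum randomly over the folds.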

Study design is not super clean (e.g., an observational study, real-world data)

Here, stratification on survival becomes more attractive. But how to do this? In the end, survival has two components: time and status.

  1. What you can do is create the stratification groups by hand: bin the survival times and paste them together with the status (and with the treatment group, if the data is sufficiently large).
  2. Another option is to replace the survival object by a numeric vector of survival probabilities, estimated e.g. with Nelson-Aalen on the pooled data. This is sometimes done in multivariate imputation involving survival columns. Combine with 1. to paste in the study group information.
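To make option 2 concrete, here is a minimal base-R sketch of the Nelson-Aalen idea: the cumulative hazard H(t) is the running sum of events divided by the number at risk at each distinct time, and the survival estimate is exp(-H(t)), looked up at each observation's own time. `nelson_aalen_surv()` is a hypothetical name used only for illustration. Note that `survfit()` returns the Kaplan-Meier estimate by default, which is usually very close.

```r
# Hypothetical helper: Nelson-Aalen estimate of S(t) = exp(-H(t)),
# evaluated at each observation's own time
nelson_aalen_surv <- function(time, status) {
  ut <- sort(unique(time))
  d <- vapply(ut, function(t) sum(time == t & status == 1), numeric(1))  # events at t
  n <- vapply(ut, function(t) sum(time >= t), numeric(1))                # at risk at t
  H <- cumsum(d / n)                                                     # cumulative hazard
  exp(-H[match(time, ut)])
}

nelson_aalen_surv(time = c(1, 2, 2, 3), status = c(1, 1, 0, 1))
```

For this toy input, H(1) = 1/4, H(2) = 1/4 + 1/3 = 7/12 and H(3) = 7/12 + 1/1 = 19/12, so the result is exp(-c(1/4, 7/12, 7/12, 19/12)).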

Here is a sketch:

library(survival)

survival_strata <- function(S, type = c("interaction", "estimate"), p = seq(0, 1, len = 5)) {
  type <- match.arg(type)
  time <- S[, "time"]

  if (type == "estimate") {
    # note: survfit() returns Kaplan-Meier estimates by default (see its stype argument)
    nelson <- summary(survfit(S ~ 1), censored = TRUE)
    return(nelson$surv[match(time, nelson$time)])
  }

  breaks <- unique(quantile(time, probs = p, names = FALSE))
  time_c <- cut(time, breaks, include.lowest = TRUE)  
  interaction(time_c, S[, "status"], drop = TRUE, sep = ":")
}

# Example
S <- with(aml, Surv(time, status))

table(survival_strata(S, "inter"))
survival_strata(S, "estimate")
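The resulting strata can then be passed to `splitTools::partition(..., type = "stratified")`. Purely to illustrate what stratified splitting on such strata roughly means, here is a base-R sketch; `stratified_split()` is a hypothetical helper, not the splitTools implementation. It takes the same share of rows from every stratum.

```r
library(survival)

# Hypothetical helper: stratified train/test split of row indices
stratified_split <- function(strata, p_train = 0.7) {
  train <- integer(0)
  for (idx in split(seq_along(strata), strata, drop = TRUE)) {
    idx <- idx[sample.int(length(idx))]                        # shuffle within the stratum
    train <- c(train, idx[seq_len(round(p_train * length(idx)))])  # same share per stratum
  }
  train <- sort(train)
  list(train = train, test = setdiff(seq_along(strata), train))
}

# Strata built as in the "interaction" sketch: time quartiles pasted with status
time_c <- cut(aml$time, unique(quantile(aml$time)), include.lowest = TRUE)
strata <- interaction(time_c, aml$status, drop = TRUE, sep = ":")
set.seed(1)
s <- stratified_split(strata)
table(strata[s$train])
```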
kapsner commented 2 years ago

Thank you very much for your detailed answer and the encouragement to discuss this in an article! This is a very good idea!

My current task is in the context of an observational health study, so your second suggestion is indeed very interesting to me.

When initially opening this issue, I had a kind of "recursive" approach in mind: first split by group, then split each group by status, and lastly split each of those by time, then recombine the fragments to get the final "survival-stratified" train/test row indices. I have sketched this below in comparison to your survival-strata approach (although I am not sure whether such an approach would be statistically meaningful).

# load the dataset
dataset <- survival::colon |>
  data.table::as.data.table()

# wrapper to create right-censored Surv-objects
surv_obj <- function(dataset) {
  survival::Surv(
    time = dataset$time,
    event = dataset$status |>
      as.character() |>
      as.integer(),
    type = "right"
  )
}

group <- "rx"
seed <- 1234

# convert to factor
dataset[, ("status") := factor(get("status"))]

# Kaplan-Meier for whole dataset
surv_all <- surv_obj(dataset)
fit_all <- survival::survfit(
  surv_all ~ rx,
  data = dataset
)
km_all <- survminer::ggsurvplot(fit = fit_all)
km_all

# split dataset
splits <- splitTools::partition(
  y = dataset[, get("status")],
  p = c(train = 0.7, test = 0.3),
  type = "stratified",
  seed = seed
)

# Kaplan-Meier for train-/test-dataset
surv_train <- surv_obj(dataset[splits$train, ])
surv_test <- surv_obj(dataset[splits$test, ])
fit_train <- survival::survfit(
  surv_train ~ rx,
  data = dataset[splits$train, ]
)
fit_test <- survival::survfit(
  surv_test ~ rx,
  data = dataset[splits$test, ]
)

survminer::arrange_ggsurvplots(
  list(
    survminer::ggsurvplot(fit = fit_train),
    survminer::ggsurvplot(fit = fit_test)
  ),
  ncol = 2,
  nrow = 1,
  print = TRUE
)


# define function to recursively split by group, status and time
split_surv <- function(dataset, split_group, split_status, split_time, p) {
  group_col <- colnames(dataset)[which(colnames(dataset) == split_group)]
  status_col <- colnames(dataset)[which(colnames(dataset) == split_status)]
  time_col <- colnames(dataset)[which(colnames(dataset) == split_time)]

  # to output the ids
  train_ids <- c()
  test_ids <- c()

  data_copy <- data.table::copy(dataset)

  # add row indices to the copy, so that the splits can be mapped back to the original rows
  data_copy[, ("split_row_indices") := seq_len(nrow(dataset))]

  group_split <- splitTools::partition(
    y = data_copy[, get(group_col)],
    p = p,
    type = "stratified",
    seed = seed
  )

  data_level_one_train <- data_copy[group_split$train, ]
  data_level_one_test <- data_copy[group_split$test, ]

  for (lvl_one in c("train", "test")) {

    status_split <- splitTools::partition(
      y = eval(parse(text = paste0("data_level_one_", lvl_one)))[, get(status_col)],
      p = p,
      type = "stratified",
      seed = seed
    )

    data_level_two_train <- eval(parse(text = paste0("data_level_one_", lvl_one)))[status_split$train, ]
    data_level_two_test <- eval(parse(text = paste0("data_level_one_", lvl_one)))[status_split$test, ]

    for (lvl_two in c("train", "test")) {

      time_split <- splitTools::partition(
        y = eval(parse(text = paste0("data_level_two_", lvl_two)))[, get(time_col)],
        p = p,
        n_bins = 10,
        type = "stratified",
        seed = seed
      )

      for (final_step in c("train", "test")) {
        append_name <- paste0(final_step, "_ids")
        assign(
          x = append_name,
          value = c(
            eval(parse(text = append_name)),
            eval(parse(text = paste0("data_level_two_", lvl_two)))[
              time_split[[final_step]],
              get("split_row_indices")
            ]
          )
        )
      }
    }
  }
  return(list(train = train_ids, test = test_ids))
}

# create split
splits2 <- split_surv(
  dataset = dataset,
  split_group = "rx",
  split_status = "status",
  split_time = "time",
  p = c(train = 0.7, test = 0.3)
)

# check, if all row-indices have been uniquely assigned to the splits
intersect(splits2$train, splits2$test)
#> integer(0)
intersect(splits2$test, splits2$train)
#> integer(0)

sum(length(splits2$train), length(splits2$test))
#> [1] 1858
nrow(dataset)
#> [1] 1858

# apply new splitting, plot Kaplan-Meier
surv_train2 <- surv_obj(dataset[splits2$train, ])
surv_test2 <- surv_obj(dataset[splits2$test, ])
fit_train2 <- survival::survfit(
  surv_train2 ~ rx,
  data = dataset[splits2$train, ]
)
fit_test2 <- survival::survfit(
  surv_test2 ~ rx,
  data = dataset[splits2$test, ]
)

survminer::arrange_ggsurvplots(
  list(
    survminer::ggsurvplot(fit = fit_train2),
    survminer::ggsurvplot(fit = fit_test2)
  ),
  ncol = 2,
  nrow = 1,
  print = TRUE
)


# example from @mayer79 https://github.com/mayer79/splitTools/issues/13#issuecomment-1182932665
survival_strata <- function(S, type = c("interaction", "estimate"), p = seq(0, 1, len = 5)) {
  type <- match.arg(type)
  time <- S[, "time"]

  if (type == "estimate") {
    nelson <- summary(survival::survfit(S ~ 1), censored = TRUE)
    return(nelson$surv[match(time, nelson$time)])
  }

  breaks <- unique(quantile(time, probs = p, names = FALSE))
  time_c <- cut(time, breaks, include.lowest = TRUE)
  interaction(time_c, S[, "status"], drop = TRUE, sep = ":")
}

# Example
S <- surv_all

table(survival_strata(S, "inter"))
#> 
#>             [8,566]:0      (566,1.86e+03]:0 (1.86e+03,2.33e+03]:0 (2.33e+03,3.33e+03]:0 
#>                    13                    57                   420                   448 
#>             [8,566]:1      (566,1.86e+03]:1 (1.86e+03,2.33e+03]:1 (2.33e+03,3.33e+03]:1 
#>                   452                   408                    47                    13
nelson_aalen_strata <- survival_strata(S, "estimate")

# create split
splits3 <- splitTools::partition(
  y = nelson_aalen_strata,
  p = c(train = 0.7, test = 0.3),
  n_bins = 10,
  type = "stratified",
  seed = seed
)

# apply nelson_aalen_strata, plot Kaplan-Meier
surv_train3 <- surv_obj(dataset[splits3$train, ])
surv_test3 <- surv_obj(dataset[splits3$test, ])
fit_train3 <- survival::survfit(
  surv_train3 ~ rx,
  data = dataset[splits3$train, ]
)
fit_test3 <- survival::survfit(
  surv_test3 ~ rx,
  data = dataset[splits3$test, ]
)

survminer::arrange_ggsurvplots(
  list(
    survminer::ggsurvplot(fit = fit_train3),
    survminer::ggsurvplot(fit = fit_test3)
  ),
  ncol = 2,
  nrow = 1,
  print = TRUE
)

mayer79 commented 2 years ago

Great plots! I think that the recursive approach is very similar to the simple "interaction" approach: first build all combinations of group, status and survival quantiles, and then use this new variable to split.

library(survival)
library(splitTools)
library(survminer)

survival_strata <- function(S, group = NULL, type = c("interaction", "estimate"), 
                            p_surv = seq(0, 1, len = 5)) {
  type <- match.arg(type)
  time <- S[, "time"]

  if (type == "estimate") {
    nelson <- summary(survfit(S ~ 1), censored = TRUE)
    return(nelson$surv[match(time, nelson$time)])
  }

  breaks <- unique(quantile(time, probs = p_surv, names = FALSE))
  time_c <- cut(time, breaks, include.lowest = TRUE)
  if (!is.null(group)) {
    return(interaction(time_c, S[, "status"], group, drop = TRUE, sep = ":"))
  }
  interaction(time_c, S[, "status"], drop = TRUE, sep = ":")
}

# Example
colon <- transform(colon, S = Surv(time, status))
strata <- survival_strata(colon$S, colon$rx, "inter")
splits <- partition(strata, c(train = 0.7, test = 0.3), type = "stratified", seed = 76)
fun <- function(split) {
  ggsurvplot(survfit(S ~ rx, data = colon, subset = split))
}
arrange_ggsurvplots(lapply(splits, fun), ncol = 2, nrow = 1, print = TRUE)
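Besides comparing the Kaplan-Meier curves, one can tabulate how each stratum is represented in each split, e.g. `strata_shares(strata, splits)` with the objects from the example above. `strata_shares()` is a hypothetical helper; the toy data below only demonstrates the output shape (one column per split, one row per stratum).

```r
# Hypothetical helper: share of each stratum within each split, to check
# that a stratified split reproduced the overall strata distribution
strata_shares <- function(strata, splits) {
  strata <- as.factor(strata)
  sapply(splits, function(idx) prop.table(table(strata[idx])))
}

# Toy example: 60 rows in stratum "a", 40 in stratum "b"
strata <- factor(rep(c("a", "b"), c(60, 40)))
splits <- list(train = c(1:42, 61:88), test = c(43:60, 89:100))
shares <- strata_shares(strata, splits)
shares
prop.table(table(strata))  # overall shares for comparison
```

Here both splits keep the overall 60/40 ratio, which is what a stratified split should achieve up to rounding.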


kapsner commented 2 years ago

Ah, now I get it :) the interaction code is much less complicated than the recursive logic!!

Do you think it would make sense to add your suggested function to splitTools in the future?

mayer79 commented 2 years ago

What we could do: If the first argument has multiple columns, one would do stratification on all of them?

kapsner commented 2 years ago

Yes, that sounds like a good approach to me.

mayer79 commented 2 years ago

How can we prevent users from accidentally passing their original data.frame instead of y? I think this is quite a common mistake when people start using "splitTools". Maybe add an extra argument "allow_multiple_y = FALSE"?
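A minimal sketch of such a guard, assuming the proposed `allow_multiple_y` argument (it does not exist in splitTools; `check_y()` is likewise a hypothetical name). `NCOL()` returns 1 for plain vectors, so only multi-column inputs trigger the error unless the user opts in explicitly.

```r
# Hypothetical input check for a function whose first argument `y` should be
# a single stratification vector; `allow_multiple_y` is the proposed opt-in
check_y <- function(y, allow_multiple_y = FALSE) {
  if (NCOL(y) > 1L && !allow_multiple_y) {
    stop("`y` has ", NCOL(y), " columns; pass a single vector ",
         "(or set allow_multiple_y = TRUE to stratify on all columns)")
  }
  invisible(y)
}

check_y(iris$Species)                    # a vector passes silently
try(check_y(iris))                       # a whole data.frame is rejected
check_y(iris, allow_multiple_y = TRUE)   # explicit opt-in
```

Note that a one-column data.frame still passes, so the check only catches the "whole dataset" mistake, not every misuse.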

kapsner commented 2 years ago

Yes, that is a good idea. Another approach could be to generate warnings, e.g. whenever y has more than 3 columns, and/or to let the code fail when the interactions result in group counts below a certain threshold (which could be defined via n_bins), e.g.:

if (any(table(survival_strata(S, "inter")) < n_bins)) {
  warning(paste0("The array provided with `y` results in group-counts < ", n_bins))
}

Edit: assuming that smaller group counts become more likely the more columns are provided.

mayer79 commented 2 years ago

@kapsner: by the way, I will add some of these insights to the vignette instead of changing the data type of the first argument y. This has too many implications, and the risk of producing garbage is too high.

kapsner commented 2 years ago

Ah :) good to know. I am working on a PR in this regard right now; maybe there is a way to include the functionality directly in partition. Otherwise, a "preprocessing" function that takes the multi-column data matrix as input and prepares the "strata vector", which can then be passed on to the original splitting function, could also be a solution.

mayer79 commented 2 years ago

I like the idea of the preprocessing function. Then, the main functions stay as they are and someone would actively need to think about what they are doing.

kapsner commented 2 years ago

I will push an update soon. Right now, everything is included into partition, however, logic could also easily be moved to a separate function.

mayer79 commented 2 years ago

@kapsner I was thinking a bit about the situation.

In my view, the idea of stratifying by multiple columns is to build m groups, each containing observations that are quite similar regarding p features. Typically, both m and p are relatively small. Still, our original strategy will probably work only with p = 2, and even then, the number of groups m will be relatively large, depending on the arguments.
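A quick arithmetic check of how fast the number of groups grows: if each of p columns is cut into k bins, the interaction has up to k^p strata (an upper bound, since `interaction(drop = TRUE)` removes empty combinations). The row count 1858 is the size of the colon data used earlier in this thread.

```r
# Upper bound on the number of interaction strata m when each of p columns
# has k levels/bins
k <- 4
p <- 1:5
m <- k^p
m

# Average number of rows per stratum for a dataset with 1858 rows
round(1858 / m, 1)

# Mixed example: 4 time quartiles x 2 status values x 3 treatment arms (rx)
m_colon <- 4 * 2 * 3
round(1858 / m_colon, 1)
```

With five binned columns, the 1858 rows would be spread over up to 1024 strata, i.e. fewer than two rows per stratum on average, which is too few for a meaningful stratified split.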

From a statistical perspective, the above idea is identical to a cluster analysis: take p columns and make m clusters. This also works for larger p if the data is sufficiently large. The difficulty is dealing with mixed-type covariates. If only numeric and categorical columns are involved (no survival), this can roughly be done as follows:

X <- model.matrix(~ 0 + x1 + x2 + ... + xp, data = data)  # x1, ..., xp: the p columns to stratify on
X <- scale(X)
stratification_column <- factor(kmeans(X, centers = m)$cluster)

There might be other good strategies, like stratified randomization (used to build treatment arms in randomized clinical trials).
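A minimal sketch of how stratified randomization with permuted blocks could be transferred to a 70/30 train/test split, assuming a block of 10 labels (7 "train", 3 "test"). `block_allocate()` is a hypothetical name, not an existing API. Within each stratum, every complete block contains exactly 7 train rows; only the last, incomplete block introduces randomness in the counts.

```r
# Hypothetical sketch of stratified permuted-block allocation: within each
# stratum, rows are assigned in blocks that each contain 7 "train" and
# 3 "test" labels in random order
block_allocate <- function(strata, block = rep(c("train", "test"), c(7, 3))) {
  out <- character(length(strata))
  for (idx in split(seq_along(strata), strata, drop = TRUE)) {
    n_blocks <- ceiling(length(idx) / length(block))
    labels <- unlist(lapply(seq_len(n_blocks), function(i) sample(block)))
    out[idx] <- labels[seq_along(idx)]  # the last block may be used only partially
  }
  out
}

set.seed(42)
strata <- rep(c("a", "b"), c(20, 15))
alloc <- block_allocate(strata)
table(strata, alloc)
```

Stratum "a" (two full blocks) always gets exactly 14 train rows; stratum "b" gets 7 from its full block plus 2 to 5 from the partial one.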

splitTools should remain free of additional heavy dependencies. Thus, I could imagine three things:

  1. Check if there is already a dedicated package for this task
  2. Add something like this to splitTools (using only the stats:: dependency)
  3. Think about a small separate package specialized in this task.
kapsner commented 2 years ago

Hi @mayer79 ,

Thanks for your thoughts on this topic. I have already tried to address some of the points you mentioned in the updated code in my fork. I have also decided to already move the code from inside the partition function to a new function called multi_strata (this name could be changed to a better one in the future).

Still, our original strategy will probably work only with p = 2, and even then, the number of groups m will be relatively large, depending on the arguments.

I agree that the number of groups will quickly get large, depending on the number of quantile splits defined for numeric variables and the number of levels of discrete variables. Therefore, messages are printed to the console to inform users about the number of columns they provided and the resulting number of groups m: https://github.com/kapsner/splitTools/blob/multi_strata/R/multi_strata.R#L45

This should make users check whether they unintentionally provided their original dataset, and it should also make them suspicious if the number of resulting groups is very large.

In addition, another message informs users when calculating the quantiles of a numeric variable yields fewer groups than they specified with the argument num_cat, which defines the number of groups numeric variables should be split into.
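A minimal sketch of such a check, assuming quantile binning with duplicated breaks removed via `unique()`, as in the `survival_strata()` code earlier in this thread. `bins_achieved()` is a hypothetical name; the actual message in the fork may be worded and implemented differently.

```r
# Hypothetical check: how many quantile bins a numeric column actually yields
# once duplicated quantile breaks are removed
bins_achieved <- function(x, n_bins) {
  probs <- seq(0, 1, length.out = n_bins + 1L)
  breaks <- unique(stats::quantile(x, probs = probs, names = FALSE))
  n <- length(breaks) - 1L
  if (n < n_bins) {
    message("Only ", n, " of ", n_bins, " requested bins could be formed")
  }
  n
}

bins_achieved(1:100, 4)                  # 4 distinct bins
bins_achieved(c(rep(0, 90), 1:10), 4)    # heavily tied data: message, 1 bin
```

Heavily tied columns (e.g. mostly zeros) are the typical case where the requested number of categories cannot be reached.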

Furthermore, I got the code working generically for p > 2, simply by creating a list that is then passed to interaction(): https://github.com/kapsner/splitTools/blob/multi_strata/R/multi_strata.R#L61
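The list trick relies on base R's `interaction()` also accepting a single list of factors. A small toy illustration with p = 3 columns (the values are made up):

```r
# p = 3 stratification columns collected in a list (toy values)
cols <- list(
  status = c(0, 1, 1, 0, 1),
  rx     = c("Obs", "Lev", "Lev", "Obs", "Obs"),
  sex    = c(1, 1, 0, 1, 1)
)

# interaction() accepts a single list of factors; drop = TRUE keeps only
# combinations that actually occur
strata <- interaction(cols, drop = TRUE, sep = ":")
table(strata)
```

Only the four observed combinations remain as levels, with "0:Obs:1" occurring twice.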

From a statistical perspective, above idea is identical to a cluster analysis: take p columns and make m clusters

I get the idea; however, I always thought that clustering relies (at least to some extent) on random numbers. As this is not the case when using interaction(), I very much liked your approach posted above, which seems more straightforward to me.

However, I also think that more sophisticated approaches might be interesting too, but I also agree with you, that they probably should go into another small package (if not already existing elsewhere).

In the end, I still think that splitTools could benefit from a very simple strategy that allows splitting based on multiple criteria.

What do you think of the so far implemented approach? https://github.com/kapsner/splitTools/compare/b91192518d5921941354c66057069ec77632006b...19804d9d089e95cc5b94a2abe8a90fce196b9ab9

mayer79 commented 2 years ago

Thanks for your work! I studied your code and worked on it a little (see below). Also see the example, which shows how nicely a cluster approach works in settings with more than two columns. Indeed, it involves some randomness, but the user can fix this with with_seed(). I have not done much testing yet and written no documentation, but I can add both if you think this goes in the right direction.

# Example (run after the function definitions below)
ms <- multi_strata(iris)
ms2 <- multi_strata(iris, "interaction")

multi_strata <- function(df, strategy = c("kmeans", "interaction"), k = 3L) {
  strategy <- match.arg(strategy)
  stopifnot(is.data.frame(df), k >= 2L, k <= nrow(df))
  FUN <- switch(strategy, "kmeans" = .kmeans, "interaction" = .interaction)
  FUN(.good_cols(df), k = k)
}

# Strategy: kmeans
.kmeans <- function(df, k) {
  # Treat ordered as numeric
  v <- colnames(df)[vapply(df, is.ordered, FUN.VALUE = logical(1L))]
  if (length(v) >= 1L) {
    df[v] <- lapply(df[v], as.integer)
  }

  # Now the real work
  df <- scale(stats::model.matrix(~ . + 0, data = df))
  factor(stats::kmeans(df, centers = k)$cluster)
}

# Strategy: interactions across all reasonable columns
.interaction <- function(df, k) {
  v <- colnames(df)[vapply(df, is.numeric, FUN.VALUE = logical(1L))]
  if (length(v) >= 1L) {
    df[v] <- lapply(df[v], .bin_pretty, n_bins = k)
  }
  interaction(df, drop = TRUE, sep = ":")
}

# Select reasonable columns
.good_cols <- function(x) {
  ok <- vapply(
    x,
    function(v) 
      is.factor(v) || is.character(v) || is.numeric(v) || is.logical(v), 
    FUN.VALUE = logical(1L)
  )
  v <- colnames(x)[ok]
  if (length(v) == 0L) {
    stop("No numeric, factor, character or logical columns in data")
  }
  x[v]
}

# Cuts x into quantile groups (like .bin, but with nice labels)
.bin_pretty <- function(x, n_bins) {
  # n_bins + 1 breaks are required, as e.g. cutting a vector with 5 breaks
  # results in 4 groups, and users specify the number of categories for
  # numeric variables via 'n_bins'
  probs <- seq(0, 1, len = n_bins + 1L)
  breaks <- unique(stats::quantile(x, probs = probs, names = FALSE))
  cut(x, breaks, include.lowest = TRUE)
}
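To illustrate the point about randomness in the "kmeans" strategy: fixing the seed right before clustering makes the strata reproducible. `kmeans_strata()` is a hypothetical wrapper for illustration; it calls `set.seed()`, which changes the global RNG state, whereas `withr::with_seed()` (the with_seed() mentioned above) would restore it afterwards.

```r
# Hypothetical wrapper: k-means strata made reproducible by fixing the seed
# (set.seed() changes the global RNG state; withr::with_seed() would not)
kmeans_strata <- function(X, k, seed) {
  set.seed(seed)
  X <- scale(X)                                       # put columns on a common scale
  factor(stats::kmeans(X, centers = k, nstart = 10)$cluster)
}

a <- kmeans_strata(iris[1:4], k = 3, seed = 1)
b <- kmeans_strata(iris[1:4], k = 3, seed = 1)
identical(a, b)
table(a, iris$Species)
```

Using `nstart = 10` also reduces the dependence of the result on a single random start.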
kapsner commented 2 years ago

@mayer79 Thanks for the update on this (I like your cleanly structured code very much; I learned some nice tricks reading your improvements to my code :smile:)

Yes, the solution goes totally into the right direction.

Regarding the documentation: I have already started documenting the 'multi_strata' function in the fork. I could open a PR with the current state, where you could add your suggested enhancements. Would that be OK for you? (I have also started working on some unit tests in the fork.)

mayer79 commented 2 years ago

Fine with me, thanks. Can I directly commit to your PR?

kapsner commented 2 years ago

Thanks! I have just opened the PR (#14) and also added the strategy to the roxygen documentation.

Can I directly commit to your PR?

You mean on GitHub? AFAIK you can fetch the pull request and work on it locally, which, however, involves creating a new branch to commit the changes to (see for example here https://stackoverflow.com/a/44992513).

mayer79 commented 2 years ago

Implemented in https://github.com/mayer79/splitTools/pull/16 and https://github.com/mayer79/splitTools/pull/14