dmlc / xgboost

Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow
https://xgboost.readthedocs.io/en/stable/
Apache License 2.0

Time dependent survival #7505

Open therneau opened 2 years ago

therneau commented 2 years ago

I would like to see time-dependent covariates added for survival data. I have many clinical research studies that could use this.

The key idea that makes this possible is that for a Cox model, the first derivative can be written as the matrix product mX, where m is the vector of martingale residuals and X is the n x p matrix of predictors; here n = number of observations in the data set and p = number of predictors. (The second derivative for a Cox model is not nearly as nice -- this makes gradient boosting, with its reliance on first derivatives, a nice match.) This should fit in well with the sparse matrix forms for X that xgboost has adopted.
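In symbols (a minimal sketch, assuming the Breslow treatment of ties): the first derivative of the Cox partial log-likelihood with respect to the per-row prediction is exactly the martingale residual, and the derivative with respect to the coefficients follows by the chain rule:

```math
\frac{\partial \ell}{\partial \eta_i} = m_i
  = \delta_i \;-\; \sum_{d:\; t_{1i} < t_d \le t_{2i}} d_d\, \frac{e^{\eta_i}}{\sum_{j \in R(t_d)} e^{\eta_j}},
\qquad
\frac{\partial \ell}{\partial \beta} = X^\top m ,
```

where the t_d are the distinct event times, d_d is the number of events at t_d, R(t_d) is the risk set, and δ_i is the status of row i. For a boosted model it is the per-row derivative m_i that matters directly.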

The data set for time-dependent outcomes needs to be in counting process form. That is, each subject is represented by multiple rows with time1, time2, status, strata, covariates. The covariate values are those that hold true during the interval from time1 to time2; status = 1 if this interval ended with an event and 0 otherwise. Strata is a special covariate, optional, that divides the subjects into disjoint groups; it opens a lot more doors for special analytical models. This form of data is now very common in software packages such as R, SAS, and Stata, so users have a lot of resources and tools for creating it. Each subject can have a unique set of intervals and time points.
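For concreteness, a small counting-process data set in R (the column names are illustrative, not a proposed xgboost interface):

```r
library(survival)

## One row per (time1, time2] interval; 'creat' is a covariate whose value
## holds over that interval.  Subject 1 has an event at t = 90, subject 2
## at t = 80, subject 3 is censored at t = 100.
df <- data.frame(
  id     = c(1, 1, 1, 2, 2, 3),
  time1  = c(0, 30, 65,  0, 40,   0),
  time2  = c(30, 65, 90, 40, 80, 100),
  status = c(0, 0, 1, 0, 1, 0),
  creat  = c(0.9, 1.1, 1.4, 1.0, 1.8, 2.0)
)
fit <- coxph(Surv(time1, time2, status) ~ creat, data = df, id = id)
resid(fit, type = "martingale")   # one residual per row
```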

A key step in all this is the routine that calculates the martingale residuals, over and over and over again. Code for this is integrated into the R survival package. Input arguments for the underlying C routine are the time1, time2, strata, and status vectors, along with the current vector of predicted values for each row of data, and two order vectors: in R they are order(strata, time1) and order(strata, time2). The routine's run time is O(n): it walks down the two order vectors and keeps running sums. I have 20 years' experience tuning and bomb-proofing this code, some of the latter perhaps not as necessary with a gbm predictor since it won't have wild outliers for the predicted values. (Users come up with the strangest data sets.)

The naive way to do this computation is O(nd), by the way, where d = number of deaths. This is the reason that the run time for SAS phreg can be 50x longer (or more) than R's coxph for large time-dependent data sets. All Cox model routines know how to do it in O(n) when there are no time-dependent covariates.
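As a reference point, here is a deliberately naive O(nd) sketch of the martingale-residual computation for counting-process data, using the Breslow baseline hazard and ignoring strata; the production C routine gets the same numbers in O(n) by walking the two sorted index vectors described above:

```r
## Naive O(nd) martingale residuals for (time1, time2] counting-process data.
## eta is the current vector of per-row predictions; row i is at risk at
## time t when time1[i] < t <= time2[i].
martingale_resid <- function(time1, time2, status, eta) {
  risk   <- exp(eta)
  dtimes <- sort(unique(time2[status == 1]))  # distinct event times
  expect <- numeric(length(eta))              # expected events per row
  for (t in dtimes) {
    atrisk <- time1 < t & t <= time2
    nevent <- sum(status == 1 & time2 == t)
    haz    <- nevent / sum(risk[atrisk])      # Breslow hazard increment
    expect[atrisk] <- expect[atrisk] + risk[atrisk] * haz
  }
  status - expect
}
```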

I have not thought through whether further speedups are possible when the linear predictor changes only a little from one computation of m to the next. That would be an interesting line of thought and discussion. I'd be happy to discuss the algorithm and hear ideas for making it better; it's easiest to do with a picture, though.

What I clearly don't know is how best to set this up in xgboost from a user's point of view. The 'response' is now 3 columns. The strata is neither fish nor fowl: not a response and not quite a predictor. I would need to cooperate with someone who knows xgboost much better. Preprocessing for the code is fairly simple: a data check that time1 < time2 for all rows, that status is 0 or 1, that strata (if present) does not put every observation into its own group (m will be all zeros in that case, and it is almost certainly a user error), and creation of the two order vectors.
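A sketch of those checks (the function name and return value are illustrative, not an xgboost API):

```r
## Validate counting-process inputs and build the two order vectors the
## residual routine needs.  Illustrative only.
prepare_surv_data <- function(time1, time2, status, strata = NULL) {
  stopifnot(all(time1 < time2), all(status %in% c(0, 1)))
  if (is.null(strata)) strata <- rep.int(1L, length(time1))
  if (!any(duplicated(strata)))
    stop("every observation is in its own stratum; m would be all zeros")
  list(ord1 = order(strata, time1),   # sort by (strata, start time)
       ord2 = order(strata, time2))   # sort by (strata, stop time)
}
```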

Terry Therneau, author of R survival

hcho3 commented 2 years ago

> The second derivative for a Cox model is not nearly as nice -- this makes gradient boosting, with its reliance on first derivatives, a nice match.

XGBoost actually uses the second-order derivative. See https://dl.acm.org/doi/10.1145/2939672.2939785 and https://arxiv.org/abs/2006.04920 for details.
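For reference, the regularized objective in the XGBoost paper is a second-order Taylor expansion of the loss at each boosting round:

```math
\mathcal{L}^{(t)} \simeq \sum_{i=1}^{n} \Big[ g_i f_t(\mathbf{x}_i) + \tfrac{1}{2}\, h_i f_t^2(\mathbf{x}_i) \Big] + \Omega(f_t),
\qquad
g_i = \partial_{\hat y_i^{(t-1)}}\, l\big(y_i, \hat y_i^{(t-1)}\big),\quad
h_i = \partial^2_{\hat y_i^{(t-1)}}\, l\big(y_i, \hat y_i^{(t-1)}\big),
```

so any objective has to supply a per-row gradient g_i and a per-row Hessian h_i; the optimal leaf weight is -Σ g_i / (Σ h_i + λ) over the rows in the leaf.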

hcho3 commented 2 years ago

Thanks for starting this discussion.

I have a question for you: why is XGBoost the best fit for this application? If XGBoost implements time-dependent survival analysis, how will it be better than your survival package?

If your proposed method works best with gradient boosting that uses first-order derivatives, it may be better to add this code to other gradient boosting libraries, since XGBoost uses second-order derivatives.

therneau commented 2 years ago

I want to use boosting for the same reasons as not using a linear model: large number of factors, non-linearity, missing values, etc. The xgboost code has the attraction of an active community. The other gbm projects in R appear to be moribund.

I'm going to have to take a longer look at gradient boosting (it's been a while since I re-read that chapter in Hastie et al.), and think about what xgboost is doing differently.

therneau commented 2 years ago

A footnote: I am consumed for the next couple of weeks with two grants that need to get out the door. One of them will use GBM heavily, which is making me dig in (again) to the nitty-gritty details. It turns out that the second derivative with respect to the current prediction-for-subject-i will be a straightforward addition to my first-derivative code. It is the second derivative wrt the betas of yhat = X beta that is more work.
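One way to see it (a sketch, again under the Breslow approximation for ties): the diagonal second derivative with respect to the per-row prediction uses the same risk-set sums as the residual itself,

```math
-\frac{\partial^2 \ell}{\partial \eta_i^2}
  = \sum_{d:\; t_{1i} < t_d \le t_{2i}} d_d\, \pi_{id}\,\big(1 - \pi_{id}\big),
\qquad
\pi_{id} = \frac{e^{\eta_i}}{\sum_{j \in R(t_d)} e^{\eta_j}},
```

which expands to e^{η_i} Σ_d d_d/S_d − e^{2η_i} Σ_d d_d/S_d², with S_d the risk-set sum: two running accumulations of the same form the O(n) residual pass already keeps. The full second derivative with respect to β brings in cross terms between rows, hence the extra work.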

I will be back once things settle down on this front. Adding time-dependent risk sets to gbm will open up a large number of time-to-event analyses, so I'm excited. The standard R gbm 2.1 series is now mothballed, i.e., bug fixes but no updates; it is what I've used for a long time. The newer gbm3 made some very unwise user-interface decisions, e.g., renaming all the commands that users use to look at the fit. I tried to point out the unwisdom of this early on but was ignored, so I have a serious lack of motivation wrt that code. (Version 3 also appears to be moribund -- several years since the last update.)

For design (back burner): two other things to think about from a user-interface point of view.

  1. When there are time intervals (time1, time2) for each observation, a given subject will often be represented as multiple observations in the data. There is then a need for an id variable that says which rows belong to which subject. It doesn't play a direct role in fitting the model, but it is important for cross-validation or subset selection: when choosing a subset you want to keep an entire person in or out (see the sketch after this list). In survival::coxph this is the 'id' option.

  2. A further class of models, which would be very useful in practice, opens up when the fit can be stratified. That is, subjects are identified as belonging to disjoint strata, with a separate intercept (baseline hazard) within each stratum. This is not the same as the partitioning one would do for cross-validation. The strata variable gets passed to the first/second derivative routine, but doesn't directly affect the fitting process otherwise. It's more like the idea of "sampling strata" that one encounters in survey work.
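A minimal sketch of the id-grouped selection in item 1 (variable names are illustrative, reusing the df from the earlier example):

```r
## Assign 5 cross-validation folds at the subject level, then propagate
## each subject's fold to all of that subject's rows, so an entire person
## is held in or out together.
set.seed(1)
subjects <- unique(df$id)
fold_of  <- sample(rep_len(1:5, length(subjects)))
df$fold  <- fold_of[match(df$id, subjects)]

train <- df[df$fold != 1, ]   # fold 1 held out...
test  <- df[df$fold == 1, ]   # ...with all of each held-out subject's rows
```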