chjackson / flexsurv

The flexsurv R package for flexible parametric survival and multi-state modelling
http://chjackson.github.io/flexsurv/

Standardised survival #102

Closed mikesweeting closed 1 year ago

mikesweeting commented 2 years ago

Hi @chjackson, great package! I've found it extremely useful.

What are your thoughts about adding functionality to flexsurv to allow standardised survival (and possibly other standardised metrics such as RMST) to be calculated?

I'd be keen to see something similar to Stata's standsurv, where it is possible to calculate standardised survival with certain covariates fixed to different counterfactual values, and average treatment effects (contrasts) can then be derived. See https://pclambert.net/software/standsurv/standardized_survival/

Could this be done either as a wrapper function around summary.flexsurvreg, or perhaps included as additional arguments to summary.flexsurvreg itself (e.g. standardise = TRUE)? Standardisation would be done by averaging the predictions over the rows of newdata.

Happy to help out, I've got a proof-of-concept wrapper script I can contribute.

chjackson commented 2 years ago

Hi Mike - I don't get much time these days to add new features that I won't use personally, but happy to consider contributions. It sounds like a good start would be a wrapper around summary.flexsurvreg, if the idea is to integrate the output of summary.flexsurvreg(..., newdata) over the covariate values given in newdata. Perhaps your code could be posted here.
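
To make the averaging idea concrete, here is a minimal base-R sketch that does not use flexsurv at all. It assumes a hypothetical Weibull proportional-hazards model whose shape, scale and coefficients (`beta`) are made-up illustration values, and a simulated `newdata`; the helper names `surv_weibull` and `stand_surv` are likewise invented for this example. It only shows what "predict for every row, fix one covariate, then average" amounts to.

```r
## Hypothetical Weibull proportional-hazards model. All parameter values below
## are made-up illustration values, not estimates from any real data set.
surv_weibull <- function(t, shape, scale, lp) {
  ## S(t | x) = exp(-scale * exp(lp) * t^shape)
  exp(-scale * exp(lp) * t^shape)
}

## Simulated covariate data to standardise over (stands in for newdata)
set.seed(1)
newdata <- data.frame(
  treat = rbinom(200, 1, 0.5),
  age   = rnorm(200, mean = 60, sd = 8)
)
beta <- c(treat = -0.5, age = 0.03)  # assumed log hazard ratios

## Standardised survival at time t with 'treat' fixed to a counterfactual value:
## predict survival for every row of the data, then average the predictions.
stand_surv <- function(t, treat_value, data = newdata) {
  data$treat <- treat_value
  lp <- beta[["treat"]] * data$treat + beta[["age"]] * (data$age - 60)
  mean(surv_weibull(t, shape = 1.2, scale = 0.1, lp = lp))
}

times <- c(1, 2, 5)
s0 <- sapply(times, stand_surv, treat_value = 0)
s1 <- sapply(times, stand_surv, treat_value = 1)

## Standardised survival under each counterfactual, plus their contrast
data.frame(time = times, at0 = s0, at1 = s1, difference = s1 - s0)
```

Because both counterfactual curves are averaged over the same covariate distribution, the difference column can be read as an average treatment effect on the survival scale, under the assumptions of the model.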

mikesweeting commented 2 years ago

Hi Chris, firstly apologies, this is a rough function that needs a lot of work to improve it. But if you feel it has some potential I would be happy to discuss further and can put some more time into it.

Confidence intervals / standard errors may be difficult to implement with parametric bootstrap if newdata has many rows. Perhaps a delta method approach may be possible, as done in standsurv in Stata?

standsurv.flexsurvreg <- function(object, newdata = NULL, at = list(), atreference = 1, type = "survival", t = NULL,
                                  ci = FALSE, contrast = NULL) {
  x <- object

  ## Add checks
  ## Currently restricted to survival or rmst 
  type <- match.arg(type, c("survival", "rmst"))
  contrast <- match.arg(contrast, c("difference", "ratio"))
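  ## Currently does not calculate CIs; match.arg() only accepts character arguments, so check the logical directly
  if(!identical(ci, FALSE)){
    stop("Confidence intervals are not yet implemented; use ci = FALSE")
  }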

  ## Check that at is a list and that all elements of at are lists
  if(!is.list(at)){
    stop("'at' must be a list")
  }
  if(any(!sapply(at, is.list))){
    stop("All elements of 'at' must be lists")
  }

  ## If at is not specified, standardise over the original or passed dataset without modifying it
  if(length(at) == 0){
    at <- list(list())
  }

  ## Contrast numbers
  cnums <- seq_along(at)[-atreference]

  ## Standardise over fitted dataset by default
  if(is.null(newdata)){
    data <- model.frame(x)
  } else{
    data <- newdata
  }

  ## Fix the covariates named in each element of at to their counterfactual values, then average the predictions
  for(i in seq_along(at)){
    dat <- data
    covs <- at[[i]]
    covnames <- names(covs)
    for (j in seq_along(covnames)) dat[, covnames[j]] <- covs[[j]]
    pred <- summary(object, type = type, tidy = TRUE, newdata=dat, t=t, ci=ci)
    predsum <- pred %>% group_by(time) %>% summarise("at{i}" := mean(est))

    if(i == 1) {
      standpred <- predsum
    } else {
      standpred <- standpred %>% inner_join(predsum, by ="time")
    }

  }

  if(contrast == "difference"){
    for(i in cnums){
      standpred <- standpred %>% mutate("contrast{i}_{atreference}" := .data[[paste0("at", i)]] - .data[[paste0("at", atreference)]])
    }
  }
  if(contrast == "ratio"){
    for(i in cnums){
      standpred <- standpred %>% mutate("contrast{i}_{atreference}" := .data[[paste0("at", i)]] / .data[[paste0("at", atreference)]])
    }
  }

  standpred
}

Example usage:

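library(flexsurv)
library(dplyr) ## standsurv.flexsurvreg() uses %>%, group_by(), summarise(), mutate() and inner_join()
## Using the bc dataset
## Create a continuous variable "age", which has some association with recyrs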
set.seed(236236)
bc$age <- rnorm(dim(bc)[1], mean = 65 - bc$recyrs, sd = 5)
spl2_age <- flexsurvspline(Surv(recyrs, censrec) ~ group+age, data=bc, k=2, scale="hazard")
standsurv.flexsurvreg(spl2_age,
                      at = list(list(group="Good"), list(group="Medium"), list(group="Poor")),
                      t=seq(0,7, length.out=100),
                      contrast = "difference")
## # A tibble: 100 x 6
##      time   at1   at2   at3 contrast2_1 contrast3_1
##     <dbl> <dbl> <dbl> <dbl>       <dbl>       <dbl>
##  1 0      1     1     1      0            0        
##  2 0.0707 1.00  1.00  1.00  -0.00000907  -0.0000275
##  3 0.141  1.00  1.00  1.00  -0.0000988   -0.000299 
##  4 0.212  1.00  0.999 0.999 -0.000399    -0.00121  
##  5 0.283  0.999 0.998 0.996 -0.00107     -0.00323  
##  6 0.354  0.998 0.996 0.992 -0.00225     -0.00680  
##  7 0.424  0.997 0.993 0.985 -0.00406     -0.0122   
##  8 0.495  0.995 0.989 0.975 -0.00656     -0.0197   
##  9 0.566  0.993 0.983 0.964 -0.00978     -0.0292   
## 10 0.636  0.990 0.976 0.949 -0.0137      -0.0408   
## # ... with 90 more rows
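
On the confidence-interval question raised above, a parametric bootstrap could look roughly like the base-R sketch below. This is not flexsurv code: `theta_hat` and `V` are made-up values standing in for a fitted model's coefficients and covariance matrix, the model is assumed to be a simple exponential one, and `stand_surv` is a hypothetical helper. The cost is B times N predictions, which is why the approach may become slow when newdata has many rows.

```r
## Made-up values standing in for coef(fit) and vcov(fit) of a fitted model:
## log baseline rate, treatment log hazard ratio, age log hazard ratio.
theta_hat <- c(log_rate = log(0.1), treat = -0.5, age = 0.03)
V <- diag(c(0.02, 0.03, 0.0001))

set.seed(2)
age <- rnorm(200, mean = 60, sd = 8)  # simulated covariate values to standardise over

## Standardised survival at time t under an assumed exponential model,
## with 'treat' fixed to treat_value for everyone
stand_surv <- function(theta, t, treat_value) {
  rate <- exp(theta[[1]] + theta[[2]] * treat_value + theta[[3]] * (age - 60))
  mean(exp(-rate * t))
}

## Parametric bootstrap: draw parameters from N(theta_hat, V), recompute the
## standardised difference for each draw, then take percentile limits.
B <- 1000
draws <- matrix(rnorm(B * length(theta_hat)), nrow = B) %*% chol(V)
draws <- sweep(draws, 2, theta_hat, "+")
diff_boot <- apply(draws, 1, function(th) {
  stand_surv(th, t = 5, treat_value = 1) - stand_surv(th, t = 5, treat_value = 0)
})

est <- stand_surv(theta_hat, 5, 1) - stand_surv(theta_hat, 5, 0)
ci  <- quantile(diff_boot, c(0.025, 0.975))
c(estimate = est, lower = ci[[1]], upper = ci[[2]])
```

A delta-method alternative would replace the B draws with a single gradient of the standardised quantity with respect to the parameters, which is the trade-off mentioned in the comment above.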