therneau / survival

Survival package for R

predict.coxph wrongly predicts survival for individuals with the event #251

Closed: rjjanse closed this issue 4 months ago

rjjanse commented 4 months ago

When I use predict() on a fitted Cox model to obtain survival probabilities, the predictions for individuals who had the event of interest are wrong.

The individual linear predictors are all correct. Since the baseline hazard is the same for every individual, it is curious that the predicted survival probabilities are correct for individuals without the event but incorrect for individuals with the event.

According to the Details section of the predict.coxph help page, the survival probability is computed not from the linear predictor but from the expected number of events: "The survival probability for a subject is equal to exp(-expected)."
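A quick way to see this relationship is to compare the two prediction types directly. This is a minimal sketch that rebuilds the same 1-year-capped model as in the example below, using base R instead of dplyr. It only checks that type = "survival" equals exp(-expected); it says nothing yet about which time point either value refers to.

```r
library(survival)
data(cancer, package = "survival")

# Cap follow-up at 1 year as in the issue: update status first, then time
lung$status <- ifelse(lung$time > 365.25, 1, lung$status)  # 1 = censored in lung
lung$time   <- pmin(lung$time, 365.25)

fit <- coxph(Surv(time, status) ~ age, data = lung)

# Expected number of events per subject, and the "survival" prediction
expected <- predict(fit, type = "expected")
surv     <- predict(fit, type = "survival")

# The survival prediction should equal exp(-expected), element by element
all.equal(surv, exp(-expected))
```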

Here is an example with a simple model that includes only age. I check the predictions by hand and with the {riskRegression} package:

```r
# Load packages
library(survival)
library(riskRegression)
library(dplyr)

# Load data
data(cancer)

# Limit follow-up time to 1 year
lung <- mutate(lung, 
               # Set status to censored (1 in lung) if follow-up exceeds a year;
               # this must come before time is capped below
               status = ifelse(time > 365.25, 1, status),
               # Cap time at 1 year
               time = ifelse(time > 365.25, 365.25, time))

# Make small Cox model
fit <- coxph(Surv(time, status) ~ age, data = lung, x = TRUE)

## Calculate by hand
# Baseline cumulative hazard at 1 year (basehaz() centres on the covariate means by default)
bh <- basehaz(fit) %>%
    # Keep all times up to 1 year
    filter(time <= 365.25) %>%
    # Keep the last row
    last() %>%
    # Extract the cumulative hazard
    `[[`("hazard")

# Linear predictor at the sample mean (the model is centred)
lp_sample <- fit[["coefficients"]][["age"]] * mean(lung[["age"]])

# Individual linear predictors
lp_ind <- fit[["coefficients"]][["age"]] * lung[["age"]]

# Final (centred) individual linear predictors
lp <- lp_ind - lp_sample

## Calculate survival probability with 3 methods
# predict.coxph function from {survival}
surv_predict <- predict(fit, type = "survival")

# predictCox function from {riskRegression}
surv_predictCox <- predictCox(fit, type = "survival", newdata = lung, times = 365.25)[["survival"]]

# Calculate by hand
surv_manual <- exp(-bh) ^ exp(lp)

## Compare
comparison <- data.frame(# Predictions
                         predict = surv_predict,
                         predictCox = surv_predictCox,
                         manual = surv_manual,
                         # Linear predictor
                         lp_predict = predict(fit, type = "lp"),
                         lp_manual = lp,
                         # Status
                         status = ifelse(lung[["status"]] == 1, "censored", "dead"))
```

The first 10 rows of the comparison are as follows:

```
> print(comparison)
      predict predictCox    manual   lp_predict    lp_manual   status
1   0.4383162  0.3385880 0.3385880  0.201410107  0.201410107     dead
2   0.3770374  0.3770374 0.3770374  0.096805314  0.096805314 censored
3   0.4532652  0.4532652 0.4532652 -0.112404274 -0.112404274 censored
4   0.6789411  0.4470012 0.4470012 -0.094970142 -0.094970142     dead
5   0.4280872  0.4280872 0.4280872 -0.042667745 -0.042667745 censored
6   0.3385880  0.3385880 0.3385880  0.201410107  0.201410107 censored
7   0.4665857  0.3770374 0.3770374  0.096805314  0.096805314     dead
8   0.3764746  0.3577986 0.3577986  0.149107711  0.149107711     dead
9   0.6881966  0.4719153 0.4719153 -0.164706671 -0.164706671     dead
10  0.7660876  0.4217470 0.4217470 -0.025233613 -0.025233613     dead
```

Here, predict is the predict.coxph method from {survival}, predictCox comes from {riskRegression}, and manual is my calculation by hand. The linear predictor lp_predict also comes from predict.coxph, and lp_manual was calculated by hand.
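For reference, the manual column evaluates the standard Cox model survival formula at a fixed horizon for every subject. Here the symbols are introduced only for illustration: $\hat\Lambda_0$ is the centred baseline cumulative hazard returned by basehaz(fit) (bh in the code) and $\hat\eta_i$ is the centred linear predictor (lp in the code):

$$
\hat S_i(t^*) = \left[\exp\{-\hat\Lambda_0(t^*)\}\right]^{\exp(\hat\eta_i)} = \exp\{-\hat\Lambda_0(t^*)\,e^{\hat\eta_i}\},
\qquad
\hat\eta_i = \hat\beta\,(\mathrm{age}_i - \overline{\mathrm{age}}),
\qquad
t^* = 365.25 .
$$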

We can see that predict does not match predictCox and manual when status == "dead", whereas all three agree when status == "censored". The linear predictors agree regardless of event status.

For reference, my sessionInfo():

```
> sessionInfo()
R version 4.3.3 (2024-02-29 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

Matrix products: default

locale:
[1] LC_COLLATE=Dutch_Netherlands.utf8  LC_CTYPE=Dutch_Netherlands.utf8    LC_MONETARY=Dutch_Netherlands.utf8 LC_NUMERIC=C                      
[5] LC_TIME=Dutch_Netherlands.utf8    

time zone: Europe/Amsterdam
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] riskRegression_2023.12.21 dplyr_1.1.4               magrittr_2.0.3            survival_3.5-8           

loaded via a namespace (and not attached):
 [1] gtable_0.3.4        xfun_0.43           ggplot2_3.5.0       htmlwidgets_1.6.4   lattice_0.22-5      numDeriv_2016.8-1.1 vctrs_0.6.5         tools_4.3.3        
 [9] generics_0.1.3      sandwich_3.1-0      parallel_4.3.3      tibble_3.2.1        fansi_1.0.6         cluster_2.1.6       pkgconfig_2.0.3     Matrix_1.6-5       
[17] data.table_1.15.4   checkmate_2.3.1     lifecycle_1.0.4     compiler_4.3.3      stringr_1.5.1       MatrixModels_0.5-3  munsell_0.5.1       codetools_0.2-19   
[25] SparseM_1.81        quantreg_5.97       htmltools_0.5.8.1   htmlTable_2.4.2     prodlim_2023.08.28  Formula_1.2-5       pillar_1.9.0        MASS_7.3-60.0.1    
[33] cmprsk_2.2-11       rms_6.8-0           Hmisc_5.1-2         iterators_1.0.14    multcomp_1.4-25     rpart_4.1.23        foreach_1.5.2       nlme_3.1-164       
[41] parallelly_1.37.1   lava_1.8.0          timereg_2.0.5       tidyselect_1.2.1    digest_0.6.35       polspline_1.1.24    mvtnorm_1.2-4       stringi_1.8.3      
[49] future_1.33.2       listenv_0.9.1       splines_4.3.3       fastmap_1.1.1       grid_4.3.3          colorspace_2.1-0    cli_3.6.2           base64enc_0.1-3    
[57] utf8_1.2.4          TH.data_1.1-2       future.apply_1.11.2 foreign_0.8-86      mets_1.3.4          scales_1.3.0        backports_1.4.1     rmarkdown_2.26     
[65] globals_0.16.3      nnet_7.3-19         gridExtra_2.3       zoo_1.8-12          evaluate_0.23       knitr_1.46          rlang_1.1.3         Rcpp_1.0.12        
[73] glue_1.7.0          rstudioapi_0.16.0   R6_2.5.1   
```

therneau commented 4 months ago
  1. You are incorrect. Type 'expected' and 'survival' from predict.coxph are the predicted cumulative hazard and survival of each subject at that subject's last follow-up time. The cumulative hazard (or expected number of deaths) plays an important role in many different computations, including the martingale residuals and relative survival.
  2. The predictCox function gives the predicted survival for each subject at a predetermined time point, in your case 365.25 days. This value plays a role in a different set of calculations. If you want this value from the survival package, use curves <- survfit(fit, newdata= lung) which will give you 228 survival curves, one per subject. Then summary(curves, time=365.25)$surv will pull off the value of each curve at 365.25 days.
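The distinction between the two quantities described above can be sketched in code. This is a minimal sketch under stated assumptions: the model is refitted with ties = "breslow" (the issue uses the default Efron method) so that a hand-built step function over basehaz() should reproduce the package values exactly; with Efron ties, small differences are possible for subjects who die at tied times. The helper name H0_at is hypothetical and introduced only for illustration.

```r
library(survival)
data(cancer, package = "survival")

# Cap follow-up at 1 year as in the issue (update status first, then time)
lung$status <- ifelse(lung$time > 365.25, 1, lung$status)
lung$time   <- pmin(lung$time, 365.25)

# Assumption: Breslow ties, so the hand-built hazard below lines up with the package
fit <- coxph(Surv(time, status) ~ age, data = lung, ties = "breslow")

lp <- predict(fit, type = "lp")        # centred linear predictor
bh <- basehaz(fit, centered = TRUE)    # centred baseline cumulative hazard

# Hypothetical helper: right-continuous step function for the cumulative hazard
H0_at <- stepfun(bh$time, c(0, bh$hazard))

# (1) Survival at each subject's OWN follow-up time, i.e. exp(-expected)
surv_own_time <- exp(-H0_at(lung$time) * exp(lp))
surv_predict  <- predict(fit, type = "survival")

# (2) Survival at a FIXED horizon of 365.25 days for every subject
surv_fixed   <- exp(-H0_at(365.25) * exp(lp))
curves       <- survfit(fit, newdata = lung)   # one curve per subject
surv_survfit <- as.vector(summary(curves, times = 365.25)$surv)

# Subjects followed up to the 1-year horizon (censored at 365.25 after capping)
at_horizon <- lung$time == 365.25
```

Comparing surv_predict with surv_own_time, and surv_survfit with surv_fixed, should show that type = "survival" is evaluated at each subject's own follow-up time, while the survfit() route evaluates every subject's curve at 365.25 days. The two coincide only for subjects in at_horizon, which is why they agree for the rows censored at one year in the table above.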

Frankly, I've never found a use for predict(fit, type='surv'). I should perhaps remove that option, as you are not the first person to be confused. BTW, in your data set subject 127 is censored on day 92; the two types of prediction do not agree for this subject either.