RaphaelS1 / survivalmodels

Implementations of survival models in R
https://raphaels1.github.io/survivalmodels/
Other
57 stars 13 forks source link

Applying Deepsurv to predict at a unique time point #12

Closed thialam closed 3 years ago

thialam commented 3 years ago

Huge kudos and thank you to you and your package, Raphael. Really appreciate your generosity.

The project idea is to compare a few predictive survival models on a patient dataset - using a simple coxph and, thanks to your package, deepsurv and others.

The question is - Is it possible to use deepsurv to predict the survival probability of a patient at a unique time point? How can I do that? (See below my humble attempt...) I was able to do it on coxph, but not successfully using deepsurv - and deepsurv's predictions are always at weird time points, and not continuous (I wonder if the reason is they are not recognised as time?).

I hope what I say makes sense - it's getting late in my timezone and my brain is mush 🙃.

library(survival)
library(survminer)
library(caret)
library(dplyr)
library(pec)
library(tidyr)
library(rms)
library(survivalmodels)
library(distr6)

# Data ----
# `gbsg` data set from the survival package, with the follow-up column
# `rfstime` renamed to `time`.
breast <- dplyr::rename(survival::gbsg, time = rfstime)

# Seeded random split: half of the rows (rounded down) go to the training set,
# the remainder to the test set.
smp_size <- floor(nrow(breast) / 2)
set.seed(123)
train_ind <- sample(seq_len(nrow(breast)), size = smp_size)

train <- breast[train_ind, ]
test <- breast[-train_ind, ]

# Survival outcome objects for each split.
Btrainsurv <- Surv(train$time, train$status)
Btestsurv <- Surv(test$time, test$status)

# Cox proportional hazards models ----
# The same formula is fitted twice: with cph() (the fit passed to survest()
# below) and with coxph() (the fit used for the single-patient prediction
# further down).
# NOTE(review): age and meno are not part of this formula, whereas the DeepSurv
# model further down is given every remaining column - confirm this is intended
# before comparing the models.
set.seed(100)
breast_coxph <- cph(
  Surv(time, status) ~ size + grade + nodes + pgr + er + hormon,
  data = train,
  # Spelled out as TRUE: `T` is an ordinary variable that can be reassigned.
  x = TRUE, y = TRUE, surv = TRUE
)
breast_coxph2 <- coxph(
  Surv(time, status) ~ size + grade + nodes + pgr + er + hormon,
  data = train
)

set.seed(100)
# Default predict() output of the cph fit for the test set.
breast_coxphall <- predict(breast_coxph, newdata = test)
# Survival probability at 300D for every test patient, using the cph model.
test300B <- survest(breast_coxph, newdata = test, times = 300)$surv
test300B
#>   [1] 0.9423018 0.8660732 0.9058816 0.8947277 0.9006159 0.9497306 0.9634248
#>   [8] 0.7675983 0.9242632 0.9406859 0.9447819 0.9077486 0.9517141 0.9429295
#>  [15] 0.9343805 0.9610584 0.9452371 0.9153133 0.9338362 0.9096485 0.9360867
#>  [22] 0.9603800 0.9513375 0.9611530 0.9219953 0.9539235 0.9060825 0.9622225
#>  [29] 0.8481742 0.9128474 0.9437010 0.9426427 0.9190690 0.9563236 0.9010333
#>  [36] 0.9039430 0.9116123 0.8969054 0.8837796 0.8958040 0.9538669 0.9388902
#>  [43] 0.8904062 0.9175153 0.9194894 0.9585355 0.9310039 0.9260767 0.9480743
#>  [50] 0.9214160 0.9339232 0.9539357 0.9635000 0.9127109 0.9449737 0.9293327
#>  [57] 0.8971494 0.9342154 0.9128745 0.8915539 0.9439411 0.9112802 0.9030807
#>  [64] 0.9418157 0.9434454 0.9232242 0.9363576 0.9597493 0.8829708 0.9274692
#>  [71] 0.9254142 0.9221072 0.9375757 0.9300823 0.9402235 0.9352237 0.9595788
#>  [78] 0.9617551 0.9238205 0.9423001 0.9079410 0.9636261 0.8910659 0.9252618
#>  [85] 0.8949935 0.9272010 0.8927181 0.9561404 0.9610383 0.9249880 0.9376848
#>  [92] 0.9490415 0.9174550 0.9179218 0.8394876 0.9435569 0.8996102 0.9225521
#>  [99] 0.9373621 0.8946869 0.9244193 0.9035021 0.9299233 0.9371766 0.8754060
#> [106] 0.9642928 0.9120585 0.9403750 0.9262709 0.9394279 0.9389775 0.9572179
#> [113] 0.8940359 0.9676482 0.9554003 0.9356705 0.9396631 0.9632153 0.9025406
#> [120] 0.9436053 0.9234920 0.9631586 0.9723019 0.9489721 0.9619528 0.8709125
#> [127] 0.9391468 0.9230935 0.9556437 0.8931054 0.9345979 0.9294030 0.9651758
#> [134] 0.9147794 0.9142944 0.9416220 0.8906988 0.9379496 0.9252564 0.9499364
#> [141] 0.8583933 0.9421687 0.9417230 0.9566134 0.9435620 0.9304639 0.9625595
#> [148] 0.9442611 0.9374856 0.9556325 0.8861134 0.9555652 0.9419080 0.9244086
#> [155] 0.9404291 0.9570388 0.9401457 0.9275752 0.9647358 0.9431958 0.9217466
#> [162] 0.9243322 0.9630762 0.9255531 0.9618069 0.9462039 0.8469048 0.8676158
#> [169] 0.8860603 0.9669560 0.9426762 0.9409940 0.9012514 0.9453365 0.9724827
#> [176] 0.9556325 0.9437663 0.9250407 0.9607768 0.9522398 0.9372213 0.9423377
#> [183] 0.9113100 0.9483347 0.9168582 0.8793319 0.9653161 0.9275415 0.9405854
#> [190] 0.9404551 0.9431689 0.9599490 0.9675647 0.9265329 0.9446302 0.9431997
#> [197] 0.9462808 0.9458786 0.9510250 0.9604479 0.9493180 0.9409301 0.9494194
#> [204] 0.9616810 0.9537476 0.9658547 0.9489930 0.9207494 0.9225947 0.7879593
#> [211] 0.9399675 0.9727249 0.9476020 0.9461957 0.9205806 0.9719695 0.9245051
#> [218] 0.9645418 0.9553563 0.9450377 0.9463020 0.9770857 0.9562950 0.7965719
#> [225] 0.9613728 0.9512316 0.9615246 0.9554209 0.9522140 0.9664446 0.9428899
#> [232] 0.9662522 0.9655260 0.9374423 0.9585066 0.9570811 0.9614406 0.9528161
#> [239] 0.9385797 0.9512622 0.9553863 0.9729576 0.9394334 0.9595081 0.9687040
#> [246] 0.9574002 0.9685433 0.9623147 0.9386035 0.9531304 0.9538296 0.9670302
#> [253] 0.9734275 0.9807234 0.9296735 0.9566624 0.9587805 0.9683413 0.9761941
#> [260] 0.9532527 0.9666107 0.9543159 0.9804255 0.9747842 0.9653386 0.9678319
#> [267] 0.9651918 0.9620218 0.9716075 0.9586307 0.9736392 0.9510621 0.9617506
#> [274] 0.9321491 0.9607512 0.9447990 0.9596613 0.9110868 0.9591103 0.9429495
#> [281] 0.9721776 0.9537925 0.9640486 0.9792691 0.9693364 0.9377982 0.9620435
#> [288] 0.9647923 0.9651809 0.9741033 0.9688974 0.9750294 0.9513490 0.9782107
#> [295] 0.9801691 0.9677386 0.9833732 0.9658926 0.9846530 0.9616047 0.9810374
#> [302] 0.9642778 0.9590667 0.9750276 0.9743006 0.9676521 0.9796927 0.9717580
#> [309] 0.9716316 0.9806136 0.9833397 0.9622777 0.9666241 0.9705137 0.9733463
#> [316] 0.9800369 0.9776668 0.9734126 0.9523249 0.9812533 0.9890190 0.9744716
#> [323] 0.9622068 0.9807912 0.9648295 0.9719413 0.9829227 0.9867563 0.9814463
#> [330] 0.9785461 0.9796552 0.9753537 0.9796075 0.9810786 0.9834765 0.9832185
#> [337] 0.9951122 0.9912232 0.9941209 0.9947936 0.9963028 0.9979973 0.9979717

# DeepSurv ----
# Drop incomplete rows, then remove the patient identifier by name (it is the
# first column of gbsg) so that it is not used as a feature.
training.naomit <- na.omit(train)
training.naomit$pid <- NULL
testing.naomit <- na.omit(test)
testing.naomit$pid <- NULL

set_seed(100)
breast_deepsurv <- deepsurv(
  data = training.naomit, frac = 0.3, activation = "relu",
  num_nodes = c(4L, 8L, 4L, 2L), dropout = 0.1,
  early_stopping = TRUE, epochs = 300L, batch_size = 32L,
  # Spelled out as TRUE: `T` is an ordinary variable that can be reassigned.
  best_weights = TRUE
)

set_seed(100)
# NOTE(review): `times` does not appear among the documented arguments of
# predict() for these models and looks to be silently ignored - confirm. The
# full survival matrix (one row per patient, one column per time point) is
# therefore requested explicitly.
breast_deepsurvall <- predict(breast_deepsurv, newdata = testing.naomit,
                              type = "survival")

# Survival probability at time 300 for every test patient. The matrix columns
# are named by time point (starting at 0), so take the last column whose time
# is <= 300, i.e. treat each curve as a step function between grid points.
deepsurv_times <- as.numeric(colnames(breast_deepsurvall))
test300_deepsurv <- breast_deepsurvall[, findInterval(300, deepsurv_times)]

# Single-patient prediction with DeepSurv ----
# Prints the whole survival curve for this patient: one column per time point
# (see the output below), not a single probability.
predict(
  breast_deepsurv,
  newdata = list(
    time = 1000, status = 1, age = 60, meno = 1, size = 12, grade = 2,
    nodes = 14, pgr = 0, er = 0, hormon = 0
  ),
  type = "survival"
)
#>   0 18 29 57 67       113       114       120       160       169       171
#> 0 1  1  1  1  1 0.9927011 0.9927011 0.9853904 0.9780552 0.9701652 0.9622943
#>         177       181       184       186       205       229       238
#> 0 0.9622943 0.9543943 0.9465137 0.9465137 0.9385603 0.9385603 0.9306244
#>         272       273       281       285       286       308       329
#> 0 0.9227045 0.9227045 0.9147846 0.9068589 0.8989733 0.8910561 0.8831534
#>         336       338      350       357       358       360       369
#> 0 0.8752973 0.8674722 0.859661 0.8518767 0.8440673 0.8362679 0.8285009
#>         370       374       377       379       385       392       420
#> 0 0.8207879 0.8131003 0.8054621 0.7978388 0.7901828 0.7825771 0.7749811
#>         426       436       438       455       456       461       465
#> 0 0.7674156 0.7598683 0.7523403 0.7448649 0.7374032 0.7374032 0.7299786
#>         471       475       476       495       500       503       515
#> 0 0.7225287 0.7150739 0.7076608 0.7002699 0.6929356 0.6855662 0.6782445
#>         518      525       537       540       541       542       544
#> 0 0.6709579 0.663726 0.6564941 0.6492726 0.6492726 0.6419612 0.6274773
#>         548       550       554       557      563      566       575       594
#> 0 0.6202431 0.6130517 0.6059205 0.5988202 0.591762 0.591762 0.5847152 0.5776849
#>         595       596       598       612       631       637       646
#> 0 0.5707156 0.5707156 0.5637035 0.5567454 0.5567454 0.5567454 0.5496683
#>         650       662      670      675       679       692       698     712
#> 0 0.5426403 0.5355304 0.528477 0.528477 0.5213377 0.5213377 0.5141791 0.50705
#>        714      721       722       723       733       734       740       747
#> 0 0.499968 0.499968 0.4928954 0.4928954 0.4928954 0.4928954 0.4928954 0.4855547
#>         753       758       761       762       768       772       776
#> 0 0.4855547 0.4855547 0.4855547 0.4780994 0.4780994 0.4780994 0.4705445
#>         779       792       797       798       827       828       838
#> 0 0.4705445 0.4705445 0.4627956 0.4627956 0.4550455 0.4550455 0.4472016
#>         855       857       859       861       865       867       889
#> 0 0.4472016 0.4472016 0.4393114 0.4313456 0.4234321 0.4156241 0.4078653
#>         893       918      956      967      972      981       983       986
#> 0 0.4002151 0.4002151 0.384944 0.384944 0.384944 0.377203 0.3694675 0.3694675
#>        1013      1036      1062      1078      1080      1088      1090
#> 0 0.3694675 0.3617248 0.3617248 0.3617248 0.3538206 0.3538206 0.3458962
#>        1094      1095      1100      1105      1119      1157      1162
#> 0 0.3458962 0.3458962 0.3458962 0.3376141 0.3376141 0.3294225 0.3213751
#>        1169      1174      1177      1182      1192      1193      1195
#> 0 0.3213751 0.3133202 0.3133202 0.3133202 0.3133202 0.3050827 0.3050827
#>        1218      1219      1225      1231      1232      1253      1277
#> 0 0.2968115 0.2968115 0.2886378 0.2886378 0.2886378 0.2803524 0.2803524
#>        1279      1296      1317      1337     1343      1352      1355
#> 0 0.2720556 0.2639393 0.2639393 0.2558313 0.247816 0.2398377 0.2398377
#>        1366     1371      1387      1388      1401      1443      1463
#> 0 0.2319737 0.224173 0.2165717 0.2091242 0.2091242 0.2091242 0.2016501
#>        1472      1486      1499      1502      1560      1582      1600
#> 0 0.2016501 0.2016501 0.2016501 0.2016501 0.2016501 0.2016501 0.2016501
#>        1625      1629      1632      1645      1675      1685      1693
#> 0 0.2016501 0.2016501 0.2016501 0.2016501 0.1928176 0.1928176 0.1928176
#>        1701      1717      1721      1722      1729     1730     1735     1751
#> 0 0.1928176 0.1928176 0.1928176 0.1928176 0.1928176 0.182743 0.182743 0.182743
#>       1760      1763      1771      1781      1791      1806      1807
#> 0 0.182743 0.1718421 0.1718421 0.1718421 0.1718421 0.1599119 0.1486078
#>        1818      1833      1846      1853      1855      1858      1869
#> 0 0.1486078 0.1486078 0.1486078 0.1486078 0.1486078 0.1486078 0.1486078
#>        1878      1905      1975      1976      1977      1984      1989
#> 0 0.1486078 0.1486078 0.1346612 0.1346612 0.1346612 0.1346612 0.1201632
#>        1990      2010      2030      2039      2051      2052      2056
#> 0 0.1068751 0.1068751 0.1068751 0.0934509 0.0934509 0.0934509 0.0934509
#>        2059      2138      2153      2156      2170      2172      2175
#> 0 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509
#>        2177      2233      2239      2297      2353      2388      2401
#> 0 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509
#>        2438      2467      2539      2551      2659 2659.001
#> 0 0.0934509 0.0934509 0.0934509 0.0934509 0.0934509        0

# Single-patient prediction with the coxph model ----
# Prints one survival probability for the patient described in `newdata`.
predict(
  breast_coxph2,
  newdata = list(
    time = 1000, status = 1, size = 12, grade = 2, nodes = 14,
    pgr = 0, er = 0, hormon = 0
  ),
  type = "survival"
)
#> [1] 0.4941188

Created on 2021-04-21 by the reprex package (v2.0.0)

Many many thanks! ☺️

RaphaelS1 commented 3 years ago

This is why I made distr6! Code below untested (I can't remember the output names so double check these if it fails):

# Request distr6 output so the survival function can be evaluated at any time.
# NOTE(review): `times` was dropped because it does not appear to be an
# argument of predict() for these models - confirm. `type = "all"` is requested
# so that the result is presumably a list with a `surv` element - verify the
# output names.
breast_deepsurvall <- predict(breast_deepsurv, newdata = testing.naomit,
                              type = "all", distr6 = TRUE)
breast_deepsurvall$surv$survival(1000) # compute survival prob at T = 1000
thialam commented 3 years ago

You're a star! Works very smoothly. Thank you so so much. I'll close this now:)