dmlc / xgboost

Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow
https://xgboost.readthedocs.io/en/stable/
Apache License 2.0

[R] Prediction returns infinite values for `survival:cox` #9979

Open jemus42 opened 10 months ago

jemus42 commented 10 months ago

In some (seemingly rare) cases, survival predictions end up being `Inf`, causing issues in downstream implementations, e.g. in mlr-org/mlr3proba.

Apologies for the somewhat contrived reprex, which uses a non-CRAN package just to load a dataset, but I had a hard time tracking this issue down and reproducing it somewhat concisely. It only occurred with specific resampling folds and hyperparameter configurations, but at least it is reliably reproducible.

if (!requireNamespace("mlr3proba")) remotes::install_github("mlr-org/mlr3proba")
#> Loading required namespace: mlr3proba

xdf = mlr3proba::grace
xdf[, c("id")] = NULL
colnames(xdf)[1:2] = c("time", "status")

rows = c(1L, 4L, 6L, 9L, 10L, 12L, 16L, 17L, 18L, 19L, 20L, 22L, 25L,
         26L, 27L, 28L, 30L, 33L, 36L, 37L, 38L, 39L, 40L, 43L, 45L, 49L,
         50L, 51L, 52L, 53L, 54L, 55L, 58L, 60L, 61L, 62L, 63L, 64L, 67L,
         68L, 71L, 73L, 74L, 75L, 76L, 77L, 79L, 80L, 81L, 83L, 85L, 86L,
         88L, 92L, 93L, 94L, 95L, 97L, 98L, 99L, 101L, 103L, 105L, 108L,
         109L, 110L, 111L, 112L, 114L, 115L, 117L, 118L, 119L, 121L, 122L,
         125L, 126L, 130L, 132L, 133L, 134L, 135L, 138L, 139L, 141L, 144L,
         145L, 148L, 149L, 150L, 152L, 154L, 157L, 159L, 160L, 162L, 163L,
         166L, 167L, 168L, 169L, 170L, 171L, 173L, 174L, 176L, 177L, 178L,
         180L, 181L, 182L, 183L, 185L, 187L, 188L, 190L, 191L, 192L, 193L,
         194L, 195L, 196L, 197L, 201L, 202L, 203L, 204L, 205L, 206L, 208L,
         209L, 210L, 211L, 213L, 214L, 215L, 216L, 218L, 221L, 224L, 225L,
         226L, 227L, 229L, 230L, 231L, 236L, 237L, 239L, 240L, 241L, 246L,
         247L, 249L, 250L, 251L, 253L, 254L, 255L, 256L, 258L, 259L, 260L,
         261L, 262L, 263L, 265L, 266L, 267L, 268L, 269L, 270L, 271L, 272L,
         273L, 274L, 275L, 279L, 280L, 281L, 282L, 283L, 285L, 286L, 289L,
         294L, 295L, 296L, 297L, 299L, 300L, 302L, 304L, 305L, 306L, 307L,
         309L, 310L, 312L, 313L, 314L, 315L, 316L, 319L, 322L, 326L, 327L,
         328L, 332L, 333L, 334L, 335L, 337L, 338L, 340L, 341L, 342L, 345L,
         346L, 349L, 350L, 351L, 352L, 353L, 354L, 355L, 356L, 357L, 360L,
         362L, 363L, 364L, 365L, 366L, 367L, 369L, 371L, 372L, 373L, 374L,
         376L, 377L, 378L, 380L, 381L, 382L, 384L, 385L, 388L, 389L, 390L,
         391L, 392L, 393L, 395L, 396L, 397L, 398L, 399L, 400L, 402L, 403L,
         404L, 406L, 407L, 409L, 410L, 411L, 413L, 415L, 416L, 417L, 418L,
         419L, 420L, 421L, 422L, 423L, 426L, 427L, 428L, 429L, 430L, 431L,
         434L, 435L, 437L, 438L, 439L, 440L, 441L, 442L, 443L, 444L, 445L,
         447L, 448L, 451L, 453L, 454L, 455L, 457L, 458L, 459L, 461L, 462L,
         463L, 464L, 465L, 467L, 468L, 472L, 473L, 474L, 476L, 477L, 478L,
         479L, 480L, 481L, 482L, 485L, 486L, 487L, 488L, 489L, 490L, 491L,
         492L, 493L, 494L, 496L, 499L, 501L, 503L, 505L, 506L, 507L, 510L,
         511L, 512L, 513L, 515L, 516L, 517L, 518L, 519L, 521L, 524L, 525L,
         528L, 529L, 532L, 535L, 536L, 537L, 539L, 540L, 545L, 546L, 548L,
         549L, 550L, 551L, 552L, 554L, 555L, 556L, 557L, 559L, 561L, 562L,
         564L, 566L, 567L, 569L, 570L, 571L, 572L, 573L, 574L, 575L, 577L,
         580L, 581L, 583L, 584L, 586L, 587L, 588L, 589L, 590L, 591L, 592L,
         594L, 595L, 596L, 602L, 603L, 604L, 606L, 608L, 609L, 610L, 611L,
         612L, 614L, 615L, 616L, 617L, 618L, 619L, 620L, 624L, 625L, 626L,
         627L, 628L, 630L, 631L, 632L, 635L, 636L, 637L, 639L, 640L, 641L,
         644L, 645L, 648L, 649L, 651L, 656L, 657L, 659L, 660L, 661L, 662L,
         663L, 664L, 665L, 666L, 667L, 669L, 671L, 672L, 674L, 675L, 678L,
         679L, 681L, 683L, 684L, 685L, 687L, 689L, 690L, 691L, 694L, 695L,
         696L, 697L, 698L, 699L, 700L, 701L, 704L, 705L, 706L, 707L, 708L,
         709L, 710L, 711L, 712L, 713L, 714L, 716L, 717L, 718L, 719L, 721L,
         722L, 723L, 725L, 727L, 728L, 730L, 732L, 735L, 736L, 738L, 739L,
         740L, 742L, 743L, 745L, 746L, 747L, 748L, 750L, 751L, 752L, 754L,
         755L, 756L, 760L, 761L, 762L, 763L, 765L, 769L, 770L, 773L, 774L,
         775L, 776L, 778L, 779L, 780L, 781L, 782L, 785L, 786L, 787L, 789L,
         790L, 793L, 796L, 797L, 798L, 799L, 800L, 801L, 803L, 804L, 805L,
         807L, 809L, 810L, 812L, 814L, 816L, 817L, 819L, 821L, 823L, 824L,
         825L, 828L, 830L, 832L, 833L, 834L, 835L, 836L, 837L, 838L, 839L,
         843L, 844L, 846L, 847L, 848L, 849L, 851L, 852L, 853L, 854L, 856L,
         858L, 860L, 861L, 862L, 863L, 864L, 865L, 867L, 868L, 872L, 873L,
         875L, 876L, 877L, 878L, 879L, 880L, 881L, 882L, 883L, 884L, 885L,
         886L, 887L, 888L, 889L, 893L, 895L, 897L, 898L, 899L, 902L, 903L,
         905L, 906L, 908L, 909L, 910L, 916L, 917L, 920L, 921L, 922L, 924L,
         925L, 926L, 928L, 929L, 931L, 932L, 933L, 934L, 935L, 936L, 937L,
         939L, 940L, 944L, 947L, 949L, 950L, 951L, 952L, 953L, 954L, 955L,
         956L, 957L, 958L, 959L, 960L, 961L, 962L, 963L, 964L, 965L, 966L,
         969L, 971L, 973L, 974L, 976L, 977L, 979L, 980L, 983L, 984L, 985L,
         986L, 987L, 988L, 989L, 990L, 994L, 995L, 996L, 999L)

data = xdf[rows, !(names(xdf) %in% c("time", "status"))]
target = xdf[rows, names(xdf) %in% c("time", "status")]

label = target[["time"]]
status = target[["status"]]

# survival:cox convention: negative labels mark right-censored observations
label[status != 1] = -1L * label[status != 1]
data = xgboost::xgb.DMatrix(
  data = as.matrix(data),
  label = label)

set.seed(1)
fit <- xgboost::xgb.train(
  data = data,
  tree_method = "hist",
  booster = "gbtree",
  objective = "survival:cox",
  nrounds = 57,
  eta = 0.9687533,
  max_depth = 2,
  eval_metric = "cox-nloglik"
)

pred <- predict(fit, data)

which(is.infinite(pred))
#>  [1] 114 121 159 213 230 231 241 246 252 260 269 295 310 323 361 615 656 658
pred[which(!is.finite(pred))]
#>  [1] Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf Inf

Created on 2024-01-10 with reprex v2.0.2

Session info

``` r
sessionInfo()
#> R version 4.3.2 (2023-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Sonoma 14.2.1
#>
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: Europe/Berlin
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base
#>
#> loaded via a namespace (and not attached):
#>  [1] utf8_1.2.4            future_1.33.0         generics_0.1.3
#>  [4] distr6_1.8.4          lattice_0.22-5        listenv_0.9.0
#>  [7] digest_0.6.33         magrittr_2.0.3        evaluate_0.23
#> [10] grid_4.3.2            ooplah_0.2.0          fastmap_1.1.1
#> [13] jsonlite_1.8.7        xgboost_1.7.5.1       Matrix_1.6-3
#> [16] backports_1.4.1       survival_3.5-7        param6_0.2.4
#> [19] fansi_1.0.5           scales_1.2.1          mlr3_0.17.0
#> [22] codetools_0.2-19      palmerpenguins_0.1.1  cli_3.6.1
#> [25] rlang_1.1.2           crayon_1.5.2          mlr3viz_0.6.1
#> [28] parallelly_1.36.0     splines_4.3.2         munsell_0.5.0
#> [31] reprex_2.0.2          withr_2.5.2           yaml_2.3.7
#> [34] mlr3pipelines_0.5.0-1 tools_4.3.2           parallel_4.3.2
#> [37] uuid_1.1-1            set6_0.2.6            checkmate_2.3.0
#> [40] dplyr_1.1.3           colorspace_2.1-0      ggplot2_3.4.4
#> [43] mlr3proba_0.5.7       globals_0.16.2        vctrs_0.6.4
#> [46] R6_2.5.1              lifecycle_1.0.4       fs_1.6.3
#> [49] dictionar6_0.1.3      mlr3misc_0.13.0       pkgconfig_2.0.3
#> [52] pillar_1.9.0          gtable_0.3.4          data.table_1.14.8
#> [55] glue_1.6.2            Rcpp_1.0.11           lgr_0.4.4
#> [58] paradox_0.11.1        xfun_0.41             tibble_3.2.1
#> [61] tidyselect_1.2.0      rstudioapi_0.15.0     knitr_1.45
#> [64] htmltools_0.5.7       rmarkdown_2.25        compiler_4.3.2
```
bblodfon commented 8 months ago

I think this is caused by the following rule (the screenshot in the original comment shows the `survival:cox` documentation): predictions are returned on the hazard-ratio scale, i.e. as `HR = exp(marginal_prediction)`,

since the marginal prediction (or linear predictor, as it is usually called) can sometimes be large enough that the exponentiation produces the `Inf`s (the second screenshot shows such values).

So a simple solution would be to return the marginal prediction (linear predictor) itself.
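
For illustration, a minimal sketch of the overflow and of a margin-scale workaround, using the existing `outputmargin` argument of `predict()` (the `fit` and `data` objects are the ones from the reprex above):

``` r
# exp() overflows double precision once its argument exceeds
# log(.Machine$double.xmax), roughly 709.78
exp(709)  #> 8.218407e+307  (still finite)
exp(710)  #> Inf

# Workaround: request the untransformed linear predictor instead of
# the hazard ratio exp(margin), so no exponentiation takes place
lp <- predict(fit, data, outputmargin = TRUE)
any(is.infinite(lp))  #> expected FALSE: the linear predictor itself stays finite
```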

hcho3 commented 8 months ago

The Cox objective might benefit from the same kind of regularization as the AFT objective, as described in Section 2.3 of https://arxiv.org/pdf/2006.04920.pdf

It also might be reasonable to clip the marginal prediction to a certain value.
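
As a user-side stopgap (a sketch only, not an xgboost API; the threshold 700 is an arbitrary value just below `log(.Machine$double.xmax)` ≈ 709.78), the clipping could be done on the returned margin:

``` r
# Clip the linear predictor before exponentiating so the
# hazard ratio stays finite
lp <- predict(fit, data, outputmargin = TRUE)
hr <- exp(pmin(lp, 700))  # 700 < log(.Machine$double.xmax) ~= 709.78
any(is.infinite(hr))  #> FALSE
```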