chjackson / flexsurv

The flexsurv R package for flexible parametric survival and multi-state modelling
http://chjackson.github.io/flexsurv/

Clean up predict and broom functions and expand test coverage #122

Closed mattwarkentin closed 2 years ago

mattwarkentin commented 2 years ago

Hi @chjackson,

I believe this PR should get predict, residuals, and the broom-compatible functions (tidy, glance, and augment) back to working order, with much expanded test coverage.

I marked this PR as a draft because there still seem to be some issues with handling missing data, and I didn't want to mess around with this internal machinery, so I thought I would leave that to you. I will document the issues more clearly in further comments.

mattwarkentin commented 2 years ago

So I think the current issues are with predictions from flexsurvspline() models for type = "response" and type = "quantile" when missing data are present:

library(flexsurv)
fit <- flexsurvspline(Surv(futime, fustat) ~ age, data = ovarian, k = 3)
ovarian_miss <- ovarian
ovarian_miss$age[[5]] <- NA

type = 'mean'

predict(fit, newdata = ovarian_miss, type = 'mean') # same as type = 'response'
#> Error in integrate(fn, start[i], t[i]): non-finite function value
11: integrate(fn, start[i], t[i]) at utils.R#158
10: rmst_generic(psurvspline, rep(Inf, nt), start = 0, matargs = c("gamma", 
        "knots"), scalarargs = c("scale", "timescale"), gamma = gamma, 
        knots = knots, beta = beta, X = X, scale = scale, timescale = timescale, 
        offset = offset)
9: (function (gamma, beta = 0, X = 0, knots = c(-10, 10), scale = "hazard", 
       timescale = "log", offset = 0) 
   {
       nt <- if (is.matrix(gamma)) 
           nrow(gamma)
       else 1
       rmst_generic(psurvspline, rep(Inf, nt), start = 0, matargs = c("gamma", 
           "knots"), scalarargs = c("scale", "timescale"), gamma = gamma, 
           knots = knots, beta = beta, X = X, scale = scale, timescale = timescale, 
           offset = offset)
   })(beta = 0, X = 0, knots = c(4.07753744390572, `25%` = 5.45570423719553, 
   `50%` = 5.88318270525789, `75%` = 6.14574211517835, 6.45833828334479
   ), scale = "hazard", timescale = "log", offset = 0, gamma = structure(c(-11.1350105158112, 
   -10.7588073823549, -12.1558250228723, -14.4358767102773, NA, 
   -13.9023494186842, -13.8141330295697, -13.3063449448222, -12.5544433682401, 
   -14.1202364164578, -13.845597861197, -15.0023132920293, -13.3454497438635, 
   -13.794102044391, -16.888532490647, -16.2181471039668, -16.9543162647352, 
   -15.9611553071519, -14.3414822153952, -16.0298105951883, -13.352602424062, 
   -10.7569104428376, -16.2157628772339, -12.7208345492076, -12.5110399998338, 
   -13.5752578753013, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
    ...
8: do.call(base.fn, list(beta = beta, X = X, knots = knots, scale = scale, 
       timescale = timescale, offset = offset, gamma = gamma))
7: x$dfns$mean(...)
6: (function (t, start, ...) 
   x$dfns$mean(...))(t = c(Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, 
   Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, 
   Inf, Inf, Inf, Inf, Inf), start = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 
   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), gamma0 = c(-11.1350105158112, 
   -10.7588073823549, -12.1558250228723, -14.4358767102773, NA, 
   -13.9023494186842, -13.8141330295697, -13.3063449448222, -12.5544433682401, 
   -14.1202364164578, -13.845597861197, -15.0023132920293, -13.3454497438635, 
   -13.794102044391, -16.888532490647, -16.2181471039668, -16.9543162647352, 
   -15.9611553071519, -14.3414822153952, -16.0298105951883, -13.352602424062, 
   -10.7569104428376, -16.2157628772339, -12.7208345492076, -12.5110399998338, 
   -13.5752578753013), gamma1 = c(2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062
   ), gamma2 = c(5.896213713062, 5.896213713062, 5.896213713062, 
   5.896213713062, 5.896213713062, 5.896213713062, 5.896213713062, 
    ...
5: do.call(fn, args) at summary.flexsurvreg.R#167
4: summary.flexsurvreg(object = object, newdata = newdata, type = stype, 
       quantiles = p, start = start, ci = conf.int, cl = conf.level, 
       se = se.fit, tidy = TRUE, na.action = na.pass, cross = TRUE)
3: summary(object = object, newdata = newdata, type = stype, quantiles = p, 
       start = start, ci = conf.int, cl = conf.level, se = se.fit, 
       tidy = TRUE, na.action = na.pass, cross = TRUE) at predict.flexsurvreg.R#188
2: predict.flexsurvreg(fit, newdata = ovarian_miss, type = "mean")
1: predict(fit, newdata = ovarian_miss, type = "mean")

type = 'quantile'

predict(fit, newdata = ovarian_miss, type = 'quantile', p = 0.5)
#> Error in rstpm2::vuniroot(h, interval, tol = .Machine$double.eps, extendInt = "yes", : f.upper = f(upper) is NA at least once
12: stop("f.upper = f(upper) is NA at least once")
11: rstpm2::vuniroot(h, interval, tol = .Machine$double.eps, extendInt = "yes", 
        maxiter = 10000) at utils.R#272
10: qgeneric(psurvspline, p = p, matargs = c("gamma", "knots"), scalarargs = c("scale", 
        "timescale"), gamma = gamma, beta = beta, X = X, knots = knots, 
        scale = scale, timescale = timescale, offset = offset) at <text>#7
9: (function (p, gamma, beta = 0, X = 0, knots = c(-10, 10), scale = "hazard", 
       timescale = "log", offset = 0, lower.tail = TRUE, log.p = FALSE) 
   {
       if (log.p) 
           p <- exp(p)
       if (!lower.tail) 
           p <- 1 - p
       qgeneric(psurvspline, p = p, matargs = c("gamma", "knots"), 
           scalarargs = c("scale", "timescale"), gamma = gamma, 
           beta = beta, X = X, knots = knots, scale = scale, timescale = timescale, 
           offset = offset)
   })(p = c(0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 
   0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 
   0.5, 0.5), beta = 0, X = 0, knots = c(4.07753744390572, `25%` = 5.45570423719553, 
   `50%` = 5.88318270525789, `75%` = 6.14574211517835, 6.45833828334479
   ), scale = "hazard", timescale = "log", offset = 0, lower.tail = TRUE, 
       log.p = FALSE, gamma = structure(c(-11.1350105158112, -10.7588073823549, 
       -12.1558250228723, -14.4358767102773, NA, -13.9023494186842, 
       -13.8141330295697, -13.3063449448222, -12.5544433682401, 
       -14.1202364164578, -13.845597861197, -15.0023132920293, -13.3454497438635, 
    ...
8: do.call(base.fn, list(p = p, beta = beta, X = X, knots = knots, 
       scale = scale, timescale = timescale, offset = offset, lower.tail = lower.tail, 
       log.p = log.p, gamma = gamma))
7: x$dfns$q(med_from_start, ...) at summary.flexsurvreg.R#357
6: (function (t, start, ...) 
   {
       start_p = 1 - x$dfns$p(start, ...)
       med_from_start = start_p * t
       ret = x$dfns$q(med_from_start, ...)
   })(t = c(0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 
   0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 
   0.5, 0.5), start = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), gamma0 = c(-11.1350105158112, 
   -10.7588073823549, -12.1558250228723, -14.4358767102773, NA, 
   -13.9023494186842, -13.8141330295697, -13.3063449448222, -12.5544433682401, 
   -14.1202364164578, -13.845597861197, -15.0023132920293, -13.3454497438635, 
   -13.794102044391, -16.888532490647, -16.2181471039668, -16.9543162647352, 
   -15.9611553071519, -14.3414822153952, -16.0298105951883, -13.352602424062, 
   -10.7569104428376, -16.2157628772339, -12.7208345492076, -12.5110399998338, 
   -13.5752578753013), gamma1 = c(2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
   2.10535935374062, 2.10535935374062, 2.10535935374062, 2.10535935374062, 
    ...
5: do.call(fn, args) at summary.flexsurvreg.R#167
4: summary.flexsurvreg(object = object, newdata = newdata, type = stype, 
       quantiles = p, start = start, ci = conf.int, cl = conf.level, 
       se = se.fit, tidy = TRUE, na.action = na.pass, cross = TRUE)
3: summary(object = object, newdata = newdata, type = stype, quantiles = p, 
       start = start, ci = conf.int, cl = conf.level, se = se.fit, 
       tidy = TRUE, na.action = na.pass, cross = TRUE) at predict.flexsurvreg.R#188
2: predict.flexsurvreg(fit, newdata = ovarian_miss, type = "quantile", 
       p = 0.5)
1: predict(fit, newdata = ovarian_miss, type = "quantile", p = 0.5)

mattwarkentin commented 2 years ago

For the issue with rmst_generic(), the missing data isn't actually detectable at: https://github.com/chjackson/flexsurv-dev/blob/01a79b31e42d0b5fe24a8c81d5ff314ffdffbbe9/R/utils.R#L158

I believe the missing value shows up in fargs_mat and is passed on to pdargs and ultimately fn, so it isn't confronted until integrate() is called. So I think the missingness needs to be handled earlier, in this for loop: https://github.com/chjackson/flexsurv-dev/blob/01a79b31e42d0b5fe24a8c81d5ff314ffdffbbe9/R/utils.R#L149-L157
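
To illustrate the idea of handling missingness inside the per-row loop, here is a minimal base-R sketch. It is not the flexsurv implementation: `rmst_na_safe` is a hypothetical name, and an exponential survival function stands in for psurvspline so the example runs without flexsurv. Rows with a missing parameter are skipped and left as NA, so integrate() is never called with an NA.

```r
# Hypothetical stand-in for the per-row loop in rmst_generic().
# Each element of `rate` plays the role of one row of the parameter matrix.
rmst_na_safe <- function(rate, t, start = 0) {
  n <- length(rate)
  t <- rep_len(t, n)
  start <- rep_len(start, n)
  out <- rep(NA_real_, n)
  for (i in seq_len(n)) {
    # Skip rows with any missing input, so integrate() never sees an NA
    if (anyNA(c(rate[i], t[i], start[i]))) next
    # Assumed survival function: exponential, used only for illustration
    surv <- function(u) pexp(u, rate = rate[i], lower.tail = FALSE)
    out[i] <- integrate(surv, start[i], t[i])$value
  }
  out
}

# The second "row" has a missing parameter and should yield NA, not an error
res <- rmst_na_safe(rate = c(0.5, NA, 2), t = Inf)
```

With this guard, the rows that have complete inputs are still integrated, and the output keeps one element per input row.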

mattwarkentin commented 2 years ago

For qgeneric() the issue is similar. The missing value shows up in args.mat and is passed on to h(), which is then passed on to rstpm2::vuniroot: https://github.com/chjackson/flexsurv-dev/blob/01a79b31e42d0b5fe24a8c81d5ff314ffdffbbe9/R/utils.R#L264-L272
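
The same kind of guard can be sketched for the root-finding step. This is an assumption-laden illustration rather than the qgeneric() code: `quantile_na_safe` is a hypothetical name, base R's uniroot() is swapped in for rstpm2::vuniroot(), and an exponential CDF stands in for psurvspline. Only rows with complete inputs reach the root finder.

```r
# Hypothetical NA-aware quantile solver; uniroot() is a stand-in for
# rstpm2::vuniroot(), and pexp() is a stand-in for psurvspline().
quantile_na_safe <- function(p, rate) {
  n <- max(length(p), length(rate))
  p <- rep_len(p, n)
  rate <- rep_len(rate, n)
  out <- rep(NA_real_, n)
  # Only rows with no missing inputs are passed to the root finder
  ok <- !is.na(p) & !is.na(rate)
  for (i in which(ok)) {
    # h() is zero where the CDF equals the requested probability
    h <- function(x) pexp(x, rate = rate[i]) - p[i]
    out[i] <- uniroot(h, interval = c(0, 1), extendInt = "upX",
                      tol = 1e-10)$root
  }
  out
}

# Median for three "rows"; the second has a missing parameter
med <- quantile_na_safe(p = 0.5, rate = c(0.5, NA, 2))
```

For the complete rows the result should agree with `qexp(0.5, rate)`, and the row with the missing parameter stays NA instead of stopping the whole call.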

chjackson commented 2 years ago

Thanks for all this. I think I can see how to fix these remaining problems, both in qgeneric and rmst_generic, by checking for NAs in args and args_mat before calling integrate or vuniroot. If you mark your PR as final I'll merge it and implement those fixes.
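
One way such a check might look, as a rough sketch only: the helper below is hypothetical (`complete_rows` is not a flexsurv function) and assumes the matrix-valued arguments have one row per prediction while the remaining arguments are scalars or vectors that recycle to that length. It returns a logical mask of rows that are safe to pass on to integrate() or vuniroot().

```r
# Hypothetical helper: TRUE for each row where every matrix-valued and
# vector-valued argument is non-missing.
complete_rows <- function(args_mat, args_vec) {
  n <- nrow(args_mat[[1]])
  ok <- rep(TRUE, n)
  # A row is incomplete if any column of any matrix argument is NA there
  for (m in args_mat) ok <- ok & stats::complete.cases(m)
  # Vector arguments are recycled to one value per row before checking
  for (v in args_vec) ok <- ok & !is.na(rep_len(v, n))
  ok
}

# Illustrative values only, loosely modelled on the gamma matrix in the traceback
gamma <- cbind(gamma0 = c(-11.1, NA, -12.2), gamma1 = c(2.1, 2.1, 2.1))
ok <- complete_rows(list(gamma = gamma), list(offset = 0, beta = 0))
```

The mask can then be used to subset the arguments before the numerical routine runs and to fill the skipped rows with NA afterwards.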

mattwarkentin commented 2 years ago

I have transferred this PR from draft to ready for review. I think if all of my tests pass after your upcoming changes then we should be good to go!