Genentech / jmpost

https://genentech.github.io/jmpost/

Weibull survival model #194

Closed ya0 closed 1 year ago

ya0 commented 1 year ago

When trying to implement joint models I had some trouble with the survival part, specifically the Weibull survival model.

I tried implementing a survival-only model and still have some trouble with the parameter estimation.

Here is a minimal reproducible example:

library(rstan)
library(cmdstanr)
library(jmpost)

# example: breastcancer data from flexsurv package
library(flexsurv)

# we analyze the group with label "Good"
g1 = subset(bc, group == "Good")

# to conform to jmpost
g1$cov = 0 #empty covariate
g1$arm = "Arm"
g1$study = "Study"
g1$ID =  as.character(1:nrow(g1)) #id as string

# get sim data because we need(?) to provide longitudinal measurements
set.seed(129)
sim_data <- simulate_joint_data(
  n_arm = c(50, 50),
  times = 1:2000,
  lambda_cen = 1 / 9000,
  beta_cat = c(
    "A" = 0,
    "B" = -0.1,
    "C" = 0.5
  ),
  beta_cont = 0.3,
  lm_fun = sim_lm_random_slope(
    intercept = 30,
    slope_mu = c(1, 2),
    slope_sigma = 0.2,
    sigma = 20,
    phi = 0.1
  ),
  os_fun = sim_os_weibull(
    lambda = 1 / 300,
    gamma = 0.97
  )
)
long_data <- sim_data$lm |>
  dplyr::filter(time %in% c(1, 50, 100, 150, 200, 250, 300)) |>
  dplyr::arrange(time, pt)

# change the id names to conform with the choice I made before
# one random measurement per person
long_data = head(long_data, nrow(g1))
long_data$pt = g1$ID

# join the data
loaded_data_surv <- DataJoint(
  survival = DataSurvival(
    data = g1,
    formula = Surv(recyrs, censrec) ~ cov,
    subject = "ID",
    arm = "arm",
    study = "study"
  ),
  longitudinal = DataLongitudinal(
    data = long_data,
    formula = sld ~ time,
    subject = "pt",
    threshold = 5
  )
)

surv_model_weibull <- JointModel(
  survival = SurvivalWeibullPH(
    lambda = prior_lognormal(2,1,init=8),
    gamma = prior_gamma(0.5 ,1, 1.5),
    beta = prior_normal(0,0.01,init = 0)
  )
)

mcmc_surv <- sampleStanModel(
  surv_model_weibull,
  data = loaded_data_surv,
  iter_sampling = 1000,
  iter_warmup = 500,
  chains = 1,
  parallel_chains = 1,
  exe_file = file.path("local", "full")
)
mcmc_surv@results

I get

> mcmc_surv@results
             variable    mean  median   sd  mad      q5     q95 rhat ess_bulk ess_tail
 lp__                 -205.56 -205.29 1.14 0.93 -207.64 -204.34 1.00      434      525
 beta_os_cov[1]          0.00    0.00 0.01 0.01   -0.01    0.02 1.00      577      529
 sm_weibull_ph_lambda    0.04    0.04 0.01 0.01    0.02    0.06 1.00      407      450
 sm_weibull_ph_gamma     1.37    1.36 0.16 0.16    1.14    1.64 1.00      401      383

Compared to the regression from the flexsurv package, the parameters seem to be a bit off. Note that the shape parameter is inverted.

est_1 <- flexsurvreg(formula = Surv(recyrs, censrec) ~ 1, data = g1,
            dist = "weibull")
est_1

gives

> est_1
Call:
flexsurvreg(formula = Surv(recyrs, censrec) ~ 1, data = g1, dist = "weibull")

Estimates: 
       est     L95%    U95%    se    
shape   1.687   1.330   2.140   0.205
scale   9.643   7.586  12.258   1.181

N = 229,  Events: 51,  Censored: 178
Total time at risk: 844.5973
Log-likelihood = -186.6605, df = 2
AIC = 377.321

and

est_2 <- flexsurvreg(formula = Surv(recyrs, censrec) ~ cov , data = g1,
            dist = "weibull")
est_2

gives

> est_2
Call:
flexsurvreg(formula = Surv(recyrs, censrec) ~ cov, data = g1, 
    dist = "weibull")

Estimates: 
       data mean  est   L95%  U95%  se    exp(est)  L95%  U95%
shape    NA       1.69    NA    NA    NA    NA        NA    NA
scale    NA       9.64    NA    NA    NA    NA        NA    NA
cov    0.00       0.00    NA    NA    NA  1.00        NA    NA

N = 229,  Events: 51,  Censored: 178
Total time at risk: 844.5973
Log-likelihood = -186.6605, df = 3
AIC = 379.321

A shape parameter of 1.69 vs 1.37 is too far off. Have I misspecified something in the jmpost survival model?

graphical comparison: plot_15

gowerc commented 1 year ago

Will take a look, though please note that flexsurv uses the AFT parameterisation of the Weibull distribution which is:

$$ f(x) = \left(\frac{a}{b}\right) \left( \frac{x}{b} \right)^{a-1} \exp\left( -\left( \frac{x}{b} \right)^{a} \right) $$

Here we have implemented the PH parameterisation of the Weibull distribution which is:

$$ f(x) = \lambda\gamma x^{\gamma-1} e^{-\lambda x^\gamma} $$

They can be converted between each other as follows:

| Parameter | AFT | PH |
| --- | --- | --- |
| Shape | $a = \gamma$ | $\gamma = a$ |
| Scale | $b = \lambda^{-1/\gamma}$ | $\lambda = b^{-a}$ |

(Should note that I don't think the different parameterisations actually have names, I just call them AFT and PH because the former naturally lends itself to the AFT model whilst the latter lends itself to the PH model).
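The conversion above can be checked numerically. The following is a minimal base R sketch, not code from either package: `aft_to_ph`, `ph_to_aft`, and `dweibull_ph` are hypothetical helper names introduced here. It writes the PH density as $\lambda\gamma x^{\gamma-1} e^{-\lambda x^\gamma}$ and compares it with `stats::dweibull`, whose shape/scale form matches the AFT density above.

```r
# Hypothetical helpers; not part of jmpost or flexsurv.
aft_to_ph <- function(a, b) {
  list(gamma = a, lambda = b^(-a))
}

ph_to_aft <- function(gamma, lambda) {
  list(a = gamma, b = lambda^(-1 / gamma))
}

# PH density: lambda * gamma * x^(gamma - 1) * exp(-lambda * x^gamma)
dweibull_ph <- function(x, lambda, gamma) {
  lambda * gamma * x^(gamma - 1) * exp(-lambda * x^gamma)
}

# flexsurv point estimates quoted in the first post
a <- 1.687
b <- 9.643

ph <- aft_to_ph(a, b)
back <- ph_to_aft(ph$gamma, ph$lambda)

# Both densities should agree at any positive x
x <- c(0.5, 1, 2, 5, 10)
densities_match <- isTRUE(all.equal(
  dweibull(x, shape = a, scale = b),
  dweibull_ph(x, lambda = ph$lambda, gamma = ph$gamma)
))
densities_match
```

Converting to PH and back should return the original `a` and `b` up to floating point error, which gives a quick sanity check before comparing estimates across the two packages.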

gowerc commented 1 year ago

Heya,

Just to say after making a few small tweaks to the code I got the following:

image

For comparison to flexsurv

# Flexsurv parameters
a = 1.69
b = 9.64

# JMpost parameters
a = gamma = 1.75
b = lambda ^(-1/gamma) =  9.245398

Which seems good enough to me :)

For reference the changes I made were:

ya0 commented 1 year ago

Thanks a lot. This looks nice. I was unsure how to handle the covariate. This change will surely help, I'll try it out tomorrow. When you say just the intercept, that is the formula Surv(..., ...) ~ 1, right?

I am aware of the different formulations of the scale parameter. That is why I focused on comparing the scale parameter. But I forgot to change the prior, which I had set for the other formulation, i.e. the flexsurv one.
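To illustrate the prior mismatch described here, the sketch below uses only base R. It assumes that the first two arguments of `prior_lognormal(2, 1, init = 8)` in the first post are the mean and standard deviation on the log scale. That is an assumption about jmpost, so check the package documentation. Under that assumption the prior is centred near the flexsurv scale $b$, while the PH parameter $\lambda = b^{-a}$ is orders of magnitude smaller.

```r
# flexsurv point estimates (AFT form)
a <- 1.69
b <- 9.64

# Implied PH rate parameter: lambda = b^(-a), roughly 0.02
lambda_ph <- b^(-a)

# Median of a lognormal with meanlog = 2, sdlog = 1 is exp(2), roughly 7.4
# (assumed to correspond to the prior used in the first post)
prior_median <- qlnorm(0.5, meanlog = 2, sdlog = 1)

# Prior probability mass at or below the implied PH value
prior_mass_below <- plnorm(lambda_ph, meanlog = 2, sdlog = 1)

# A log-scale centre that would match the PH value instead
meanlog_ph <- -a * log(b)

c(lambda_ph = lambda_ph,
  prior_median = prior_median,
  prior_mass_below = prior_mass_below,
  meanlog_ph = meanlog_ph)
```

A prior centred on `meanlog_ph` (with a suitably wide standard deviation) would place its mass around the PH value of $\lambda$ rather than around $b$.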

gowerc commented 1 year ago

Yup, intercept only is Surv(...) ~ 1. Let me know if you run into any other issues or need any other help.

gowerc commented 1 year ago

@ya0 - Let me know if we can close this or not 😄

ya0 commented 1 year ago

This worked for the bc data for me as well! Not sure why the zero covariates interfered with the sampling process.
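One thing that can be said with certainty about an all-zero covariate is that it contributes nothing to the linear predictor, so the data carry no information about its coefficient. The base R sketch below shows this on a small made-up data frame (hypothetical values, not the bc data). It does not by itself explain the difference in the estimated shape.

```r
# Toy data with hypothetical values; only the all-zero `cov` column matters here.
toy <- data.frame(
  recyrs  = c(1.2, 3.4, 0.7, 5.1),
  censrec = c(1, 0, 1, 0),
  cov     = 0
)

X_cov <- model.matrix(~ cov, data = toy)  # intercept plus an all-zero column
X_int <- model.matrix(~ 1, data = toy)    # intercept only

# The zero column adds a column but no rank.
rank_cov <- qr(X_cov)$rank
rank_int <- qr(X_int)$rank

# For any coefficient value, the covariate's contribution is zero.
betas <- c(-2, 0, 3)
contrib <- sapply(betas, function(beta) X_cov[, "cov"] * beta)
all(contrib == 0)
```

Because the likelihood is flat in that coefficient, its posterior is determined by the prior alone, which is why dropping the column in favour of `~ 1` loses nothing.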