Open · soumyasahu opened this issue 3 years ago
Hi @soumyasahu - so from memory I hardcoded the maximum of three biomarkers because it simplified the code in the .stan file, which was already overly complicated; the Stan file has to be pre-compiled in rstanarm, and therefore can't be built dynamically for an unbounded number of biomarkers. With a bit of thought it would probably have been possible to make it unbounded using some smart indexing, but that would have added complexity in other ways (e.g. tracking all the different indexes).
The other reason a limit of three seemed somewhat reasonable is that stan_jm assumes an unstructured covariance matrix across the random effects for the biomarker submodels. So if you had 21 biomarkers, each with a random intercept and slope, we would need to estimate a 42x42 covariance matrix, which wouldn't be feasible.
If you plan to have 21 biomarkers, you'll likely need to restrict the covariance matrix, e.g. assume independence between some or all biomarkers (or something similar). In any case, that type of model would need to be implemented outside stan_jm I think, perhaps using the Stan code in the conference paper as a starting point, like you mentioned...
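Just to give a rough idea of what I mean by restricting the covariance matrix - something like the following sketch, which assumes a random intercept and slope per biomarker and independence between biomarkers (the names and the hard-coded 21 are only illustrative, and this is not code from rstanarm):

data {
  int<lower=1> b_N;  // number of subjects
}
parameters {
  // one small 2x2 block (random intercept + slope) per biomarker, instead
  // of a single unstructured 42x42 covariance matrix across all biomarkers
  vector<lower=0>[2] b_sd[21];             // group level sds, per biomarker
  cholesky_factor_corr[2] b_cholesky[21];  // within-biomarker correlation
  matrix[2, b_N] z_b_mat[21];              // unscaled group level params
}
transformed parameters {
  matrix[2, b_N] b_mat[21];  // scaled group level params, per biomarker
  for (m in 1:21)
    b_mat[m] = diag_pre_multiply(b_sd[m], b_cholesky[m]) * z_b_mat[m];
}
model {
  for (m in 1:21) {
    target += normal_lpdf(to_vector(z_b_mat[m]) | 0, 1);
    target += lkj_corr_cholesky_lpdf(b_cholesky[m] | 1);
    target += student_t_lpdf(b_sd[m] | 3, 0, 2.5);  // half-t via the <lower=0> bound
  }
}

With 21 independent 2x2 blocks you only estimate 21 within-biomarker correlations and 42 sds, rather than the 861 free correlations of a full 42x42 matrix.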
Hi,
I have implemented the model with 6 longitudinal biomarkers, using a regularized horseshoe prior on the association parameters in the survival submodel. For now we have used a completely diagonal variance-covariance matrix for the random effects, and instead of splines I have used a Weibull distribution for the baseline hazard. The problem I am facing is that the chains run too slowly. To understand the issue, I simplified the code to a single longitudinal biomarker and compared the run time with stan_jm. I found that even with one biomarker the chains are ridiculously slow compared to the implementation in this package.
I am sharing the code for one biomarker below. It is almost the same as the code you shared in the conference paper. If you could share some suggestions for a faster implementation, it would be very helpful for me.
code <- "
functions {
vector evaluate_eta(matrix X, matrix Z, int[] Z_id, vector beta, matrix bMat) {
int N = rows(X); // num rows in design matrix
int K = rows(beta); // num predictors
int p = cols(Z); // num group level params
vector[N] eta;
eta = X * beta;
for (k in 1:p){
for (n in 1:N){
eta[n] = eta[n] + bMat[k, Z_id[n]] * Z[n, k];
}
}
return eta;
}
}
data {
int <lower=1> y_K; // number of variables in longitudinal submodel + 1 (intercept)
int <lower=1> y_N; // total number of observations for longitudinal variable
int <lower=1> e_K; // number of variables in event submodel + 1 (intercept)
int <lower=0> b_K; // number of random effects in longitudinal submodel
int <lower=1> b_N; // number of subjects
matrix[y_N, y_K] y_X; // fixed effect model matrix in longitudinal submodel
matrix[y_N, b_K] y_Z; // random effect model matrix in longitudinal submodel
int <lower=0> Z_id[y_N];
vector[y_N] y1;
int <lower=1> nrow_y_Xq;
int <lower=1> nrow_e_Xq;
int <lower=0> Nevents;
matrix[nrow_y_Xq, y_K] y_Xq;
matrix[nrow_y_Xq, b_K] y_Zq;
int <lower=0> Zq_id[nrow_y_Xq];
matrix[nrow_y_Xq, e_K] e_Xq;
int <lower=1> Npat_times_qnodes;
vector[nrow_y_Xq] times;
vector[Npat_times_qnodes] qwts;
}
parameters {
vector[y_K] y1_mean; // primitive coefs in long. submodels
vector<lower=0>[y_K] y1_scale; // primitive coefs in long. submodels
vector[y_K] y1_z_beta; // primitive coefs in long. submodels
vector[e_K] e_mean;
vector<lower=0> [e_K] e_scale;
vector[e_K] e_z_beta; // primitive coefs in event submodel (log hazard ratios)
vector[1] a_mean;
vector<lower=0>[1] a_scale;
vector[1] a_z_beta; // primitive assoc params (log hazard ratios)
real<lower=0> y1_aux; // unscaled residual error SDs
real<lower=0> alpha;
real<lower=0> sigma;
// group level params
vector<lower=0>[b_K] b_sd1; // group level sds
//vector<lower=0>[b_K] b_sd2;
matrix[b_K,b_N] z_b_mat1; // unscaled group level params
cholesky_factor_corr[b_K] b_cholesky1; // cholesky factor of corr matrix
}
transformed parameters {
vector[y_K] y1_beta; // primitive coefs in long. submodels
vector[e_K] e_beta; // primitive coefs in event submodel (log hazard ratios)
vector[1] a_beta; // primitive assoc params (log hazard ratios)
matrix[b_K,b_N] b_mat1;
// coefs for long. submodels
y1_beta = y1_z_beta .* y1_scale + y1_mean;
// coefs for event submodel (incl. association parameters)
e_beta = e_z_beta .* e_scale + e_mean;
a_beta = a_z_beta .* a_scale + a_mean;
// residual error SDs for long. submodels
// group level params
b_mat1 = (diag_pre_multiply(b_sd1, b_cholesky1) * z_b_mat1);
}
model {
//---- Log-lik for longitudinal submodels
{
// declare linear predictors
vector[y_N] y1_eta;
// evaluate linear predictor for each long. submodel
y1_eta = evaluate_eta(y_X, y_Z, Z_id, y1_beta, b_mat1);
// increment the target with the log-lik
target += normal_lpdf(y1 | y1_eta, y1_aux);
}
//----- Log-lik for event submodel (Gauss-Kronrod quadrature)
{
vector[nrow_y_Xq] y1_eta_q;
vector[nrow_y_Xq] e_eta_q;
vector[nrow_y_Xq] log_basehaz;
vector[nrow_y_Xq] log_haz_q;
vector[Nevents] log_haz_etimes;
vector[Npat_times_qnodes] log_haz_qtimes;
// Step 1: event submodel linear predictor at event time and quadrature points
e_eta_q = e_Xq * e_beta;
// Step 2: long. submodel linear predictor at event time and quadrature points
y1_eta_q = evaluate_eta(y_Xq, y_Zq, Zq_id, y1_beta, b_mat1);
// Step 2 (continued): add on contribution from association structure to
// the event submodel linear predictor at event time and quadrature points
e_eta_q = e_eta_q + a_beta[1] * y1_eta_q;
// Step 3: log baseline hazard at event time and quadrature points
for (i in 1:nrow_y_Xq){
log_basehaz[i] = weibull_lpdf(times[i] | alpha, sigma);
}
// Step 4: log hazard at event time and quadrature points
log_haz_q = log_basehaz + e_eta_q;
// Step 5: log hazard at event times only
// (i.e. log hazard contribution to the likelihood)
log_haz_etimes = head(log_haz_q, Nevents);
// Step 6: log hazard at quadrature points only
log_haz_qtimes = tail(log_haz_q, Npat_times_qnodes);
// Step 7: log likelihood for event submodel
target += sum(log_haz_etimes) - dot_product(qwts, exp(log_haz_qtimes));
}
//----- Log-priors
// coefficients for long. submodels
target += normal_lpdf(y1_z_beta | 0, 1);
target += normal_lpdf(y1_mean | 0, 1);
target += inv_gamma_lpdf(rows_dot_self(y1_scale) | .5, .5);
// coefficients for event submodel
target += normal_lpdf(e_z_beta | 0, 1);
target += normal_lpdf(e_mean | 0, 1);
target += inv_gamma_lpdf(rows_dot_self(e_scale) | .5, .5);
target += normal_lpdf(a_z_beta | 0, 1);
target += normal_lpdf(a_mean | 0, 1);
target += inv_gamma_lpdf(rows_dot_self(a_scale) | .5, .5);
// residual error SDs for long. submodels
target += inv_gamma_lpdf(y1_aux*y1_aux | .5, .5);
// Weibull distribution for baseline hazard
target += inv_gamma_lpdf(alpha | .5, .5);
target += inv_gamma_lpdf(sigma | .5, .5);
// group level terms
// sds
target += inv_gamma_lpdf(rows_dot_self(b_sd1) | .5, .5);
// primitive coefs
target += normal_lpdf(to_vector(z_b_mat1) | 0, 1);
// corr matrix
target += lkj_corr_cholesky_lpdf(b_cholesky1 | 1);
}
"
model <- stan(model_code = code,
              data = list(y_K = ncol(X),
                          y_N = nrow(X),
                          e_K = ncol(X_surv),
                          b_K = ncol(Z),
                          b_N = nrow(X_surv),
                          y_X = X,
                          y_Z = Z,
                          Z_id = Z_id,
                          y1 = y1,
                          nrow_y_Xq = (nrow(X_q_conv) + nrow(X_q)),
                          nrow_e_Xq = (nrow(e_X_q_conv) + nrow(e_X_q)),
                          Nevents = nrow(e_X_q_conv),
                          y_Xq = rbind(X_q_conv, X_q),
                          y_Zq = rbind(Z_q_conv, Z_q),
                          Zq_id = c(Z_q_id_conv, Z_q_id),
                          e_Xq = rbind(e_X_q_conv, e_X_q),
                          Npat_times_qnodes = nrow(X_q),
                          times = c(SurvData$surv_time[converted_id], all_times),
                          qwts = all_weights),
              # init = initf,
              chains = 1,
              iter = 1000,
              warmup = 500,
              control = list(max_treedepth = 15))
As an update, I have removed the second (inner) 'for' loop in the 'evaluate_eta' function, which has increased the speed (a rough sketch of the vectorised version is below), but it is still much slower than your package. For the details: 1000 iterations (500 warmup + 500 sampling) took 4600 sec with the package (4400 sec for warmup and 200 sec for sampling), while my code took 16000 sec (8000 sec each for warmup and sampling).
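In case it is useful, the vectorised function now looks roughly like this (a sketch only, reusing the names from the code above; I am not claiming this is how rstanarm does it internally):

functions {
  // Sketch of a vectorised linear predictor: the inner loop over the N
  // observations is replaced by gathering each observation's random
  // effects with the multi-index Z_id, leaving only a loop over the
  // p group-level terms.
  vector evaluate_eta(matrix X, matrix Z, int[] Z_id, vector beta, matrix bMat) {
    int N = rows(X);  // num rows in design matrix
    int p = cols(Z);  // num group level params
    vector[N] eta = X * beta;
    for (k in 1:p) {
      // bMat[k, Z_id] picks, for every observation, the k-th random effect
      // of the subject that the observation belongs to
      eta += to_vector(bMat[k, Z_id]) .* col(Z, k);
    }
    return eta;
  }
}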
I admit that I am a beginner at Stan coding and don't have a detailed understanding of what goes on inside the Stan computation. What puzzles me is that Stan's message "1000 transitions using 10 leapfrog steps per transition would take ... seconds" reports 110 seconds for my code but 180 seconds for the package, yet I am also observing that the sampling phase is really very fast in the package. Can you give me any suggestions about this?
Hi @soumyasahu - sorry for the slow reply. To be honest, I can't remember each of the optimisations we made between that conference paper and the latest version, but there were quite a few changes along the way! And I doubt I'll find time to look into the details of the code to make specific suggestions.
Having said that, one thing to consider is priors. The rstanarm version probably does a bunch of autoscaling of priors, and that might not have been happening in the conference paper code? If not, then it would be worth thinking about the priors - are they too vague / flat in your example, and if so, can you constrain them more to improve sampling?
Also, for optimisations in the Stan code itself, we'd have to compare against the jm.stan model and the Stan code chunks it #includes. Have you tried stepping through the jm.stan file? https://github.com/stan-dev/rstanarm/blob/master/src/stan_files/jm.stan. I realise it is very dense for someone not familiar with the code, but perhaps it would help enlighten us. Although there are no guarantees it is something in the .stan file - it's possible the speed gain comes from the preprocessing of the data, e.g. the prior autoscaling I mentioned...
Hi @sambrilleman, thanks for the reply. I looked at the Stan files; they mostly deal with the implementation of splines for the baseline hazard and the shared random and fixed effects modelling. I am fitting a simpler model here, involving only the mean trajectories in the survival submodel and a Weibull baseline hazard.
I have noticed that although the runtime of my code is slower, the inference is pretty close to that of stan_jm.
I looked at the autoscaling of the priors; it is done based on the scale of each particular variable. I also noticed that the variables are centered; I am not sure whether centering has anything to do with the speed. I would appreciate suggestions on the following two things: (1) In the package, the sampling iterations are ridiculously faster than the warm-up iterations (4400 sec for 500 warmup iterations but 200 sec for 500 sampling iterations). Is there any way to achieve that? This looks really surprising and interesting to me. (2) How is the rescaling done here? Can you give me an idea? In the code, any coefficient beta is written as beta = beta_z * beta_scale + beta_mean, where beta_z ~ Normal(0, 1), so we need to put priors on beta_scale and beta_mean. But when I look at the prior_summary from the model fitted by stan_jm, it shows beta ~ normal(location = 0, scale = x), which is confusing when setting a prior for each beta_scale and beta_mean. Also, what kind of priors have you used for the scale parameters?
I understand that it is very difficult to find time to check the code, but if you could answer these two questions, it would be a great help for me.
I also noticed that the variables are centered; I am not sure whether centering has anything to do with the speed.
It may well be. The centering of the variables can influence the estimate of the baseline hazard, and in particular the value of the intercept in the regression (which, from memory, is the log scale of the Weibull, I think?), so centering might impact sampling in that way.
There might also be something about the need to center the covariate that corresponds to the expected value of the biomarker. If you aren't centering that in the survival model, then the expected value of the survival model can be sensitive to changes in the expected value of the biomarker, and the latter can fluctuate a lot during sampling if the priors on the regression coefficients in your biomarker model aren't tightly constrained. So centering the expected value of the biomarker might reduce that sensitivity a bit. Hopefully what I am saying makes sense - it's been a while (a couple of years!) since I've been working with joint models.
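For concreteness, one way to do the covariate centering is inside the Stan program itself - a minimal sketch, assuming the first column of the design matrix is the intercept (the names mirror your code, but this isn't the rstanarm implementation, and the same idea applies to the event submodel matrices):

data {
  int<lower=1> y_N;
  int<lower=1> y_K;       // column 1 assumed to be the intercept
  matrix[y_N, y_K] y_X;
}
transformed data {
  matrix[y_N, y_K] y_Xc;  // centered copy of the design matrix
  for (k in 1:y_K) {
    real xbar = (k == 1) ? 0.0 : mean(col(y_X, k));  // leave the intercept column alone
    for (n in 1:y_N)
      y_Xc[n, k] = y_X[n, k] - xbar;
  }
}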
(1) In the package, the sampling iterations are ridiculously faster than the warm-up iterations (4400 sec for 500 warmup iterations but 200 sec for 500 sampling iterations). Is there any way to achieve that? This looks really surprising and interesting to me.
This must just be because in the stan_jm implementation the warmup manages to find good settings for sampling - e.g. an appropriate step size (or whatever it is called) for the sampler - whereas in your parameterisation it must struggle to ever find an appropriate step size. Which I guess suggests that some reparameterisation is needed to make the joint posterior easier to explore - which I guess is what we are trying to find here by discussing the centering / scaling / etc!
(2) How is the rescaling done here? Can you give me an idea? In the code, any coefficient beta is written as beta = beta_z * beta_scale + beta_mean, where beta_z ~ Normal(0, 1), so we need to put priors on beta_scale and beta_mean.
I think this is just the non-centered parameterisation: https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html. beta_scale and beta_mean are just constants, I think (i.e. the regression coefs just have normal priors with mean beta_mean and sd beta_scale) - see the priorLong and priorEvent arguments in the docs: http://mc-stan.org/rstanarm/reference/stan_jm.html.
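To illustrate what I mean by constants, here is a minimal sketch of the non-centered parameterisation with the prior mean and scale passed in as data (generic names, not the actual rstanarm variable names):

data {
  int<lower=1> K;
  vector[K] prior_mean;            // fixed constants (e.g. all zeros)
  vector<lower=0>[K] prior_scale;  // fixed constants (e.g. from the autoscaling)
}
parameters {
  vector[K] z_beta;                // standard normal primitives
}
transformed parameters {
  // implies beta ~ normal(prior_mean, prior_scale), but the sampler works
  // on the unit-scale z_beta (non-centered parameterisation)
  vector[K] beta = z_beta .* prior_scale + prior_mean;
}
model {
  target += normal_lpdf(z_beta | 0, 1);
}

Since prior_mean and prior_scale are data, they add nothing to the parameter space; only z_beta is sampled.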
But when I look at the prior_summary from the model fitted by stan_jm, it shows beta ~ normal(location = 0, scale = x), which is confusing when setting a prior for each beta_scale and beta_mean.
I think this just implies that beta_mean = 0 in your previous question. Which would make sense - the regression coefs are almost certainly drawn from a symmetric (i.e. normal) prior with mean 0.
Also, what kind of priors have you used for the scale parameters?
I'm not exactly sure which scale parameter you are referring to here, but...
The beta_scale in your question above is just a constant, I think. It is determined using the "autoscaling" logic, which is a little bit complicated, but it is described in the docs, e.g. http://mc-stan.org/rstanarm/reference/priors.html.
Alternatively, there is a scale parameter for the biomarker model - this is described under the priorLong_aux argument in the docs: http://mc-stan.org/rstanarm/reference/stan_jm.html. I think it says there that it is a half-Cauchy...
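And just to give the flavour of the autoscaling (this is an illustration only, not the actual rstanarm rules, which are in the priors docs linked above): the general idea is to adjust the prior scale for each coefficient according to the sd of its predictor, which you could mimic in a transformed data block along these lines:

data {
  int<lower=1> y_N;
  int<lower=1> y_K;       // column 1 assumed to be the intercept
  matrix[y_N, y_K] y_X;
}
transformed data {
  // prior scales for the coefficients, adjusted for each predictor's sd,
  // so predictors measured on large scales get proportionally tighter priors
  vector<lower=0>[y_K] y1_prior_scale;
  for (k in 1:y_K)
    y1_prior_scale[k] = (k == 1) ? 10.0 : 2.5 / sd(col(y_X, k));
}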
Hope some of that helps...!
Hi @sambrilleman,
Thanks a lot for your detailed reply. I standardized the biomarkers, rescaled the priors, and, most importantly, used fixed values for beta_mean and beta_scale instead of assigning priors to them. My code now matches the speed of the package. Next, I am going ahead with accommodating all 21 biomarkers with a shrinkage prior; for simplicity, I am using a completely diagonal covariance matrix.
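For reference, the kind of shrinkage prior I have in mind for the association parameters is the regularised horseshoe of Piironen and Vehtari - roughly something like this sketch (the names and hyperparameters are illustrative, and this is not code from rstanarm):

data {
  int<lower=1> M;                 // number of association parameters (e.g. 21)
  real<lower=0> hs_global_scale;  // scale for the global shrinkage parameter
  real<lower=0> hs_slab_df;       // slab degrees of freedom
  real<lower=0> hs_slab_scale;    // slab scale
}
parameters {
  vector[M] a_z;                  // raw association parameters
  vector<lower=0>[M] hs_lambda;   // local shrinkage parameters
  real<lower=0> hs_tau;           // global shrinkage parameter
  real<lower=0> hs_c2;            // squared slab scale
}
transformed parameters {
  vector[M] a_beta;               // association parameters for the event submodel
  {
    vector[M] lambda2 = square(hs_lambda);
    vector[M] lambda_tilde = sqrt(hs_c2 * lambda2 ./ (hs_c2 + square(hs_tau) * lambda2));
    a_beta = a_z .* lambda_tilde * hs_tau;
  }
}
model {
  target += normal_lpdf(a_z | 0, 1);
  target += cauchy_lpdf(hs_lambda | 0, 1);             // half-Cauchy via <lower=0>
  target += cauchy_lpdf(hs_tau | 0, hs_global_scale);  // half-Cauchy via <lower=0>
  target += inv_gamma_lpdf(hs_c2 | 0.5 * hs_slab_df,
                           0.5 * hs_slab_df * square(hs_slab_scale));
}

I would then use a_beta in place of the fixed-prior association parameters, multiplying the biomarker linear predictors in the event submodel.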
Thanks again, it is a great help.
Description:
In our problem, we have 21 longitudinal biomarkers and one survival outcome. Our goal is to find out which of the biomarkers are significantly associated with the survival outcome. I have used this package with one longitudinal biomarker at a time and found that, for some of the biomarkers, not only the mean trajectory but also its slope significantly affects the survival outcome. My next plan is to use all the longitudinal biomarkers together and model their associations with survival jointly. For variable selection, we can use some standard shrinkage prior on the 'assoc' and 'assoc_d' parameters.
In this package, I have noticed that the user is not allowed to use more than 3 longitudinal outcomes. I want to know what the issue is with using more than 3 longitudinal biomarkers. You have shared the main part of the Stan code in a conference paper, and my plan is to use that to implement my idea. Do you have any suggestions regarding this?