paul-buerkner / brms

brms R package for Bayesian generalized multivariate non-linear multilevel models using Stan
https://paul-buerkner.github.io/brms/
GNU General Public License v2.0

Use multiple indexing in Stan code for varying effects #772

Open paul-buerkner opened 4 years ago

paul-buerkner commented 4 years ago

Currently, the Stan code of a multilevel model looks a little verbose because it first indexes columns and then loops over observations to select the right elements of the computed vectors. This has historically been more efficient than the other indexing options available in Stan. However, with the multiple indexing feature of Stan, a much less verbose option should be available.

Preliminary analysis suggests that this will actually make the sampling less efficient (see branch 're-multiple-indexing') but more testing is required to say something reliable about the efficiency aspect.

Here is how the Stan code of a varying intercept, varying slope model currently looks:

data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  real<lower=0> sigma;  // residual SD
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  // cholesky factor of correlation matrix
  cholesky_factor_corr[M_1] L_1;
}
transformed parameters {
  // actual group-level effects
  matrix[N_1, M_1] r_1 = (diag_pre_multiply(sd_1, L_1) * z_1)';
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1 = r_1[, 1];
  vector[N_1] r_1_2 = r_1[, 2];
}
model {
  // initialize linear predictor term
  vector[N] mu = Intercept + Xc * b;
  for (n in 1:N) {
    // add more terms to the linear predictor
    mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n];
  }
  ...
}

Here is how the Stan code of a varying intercept, varying slope model could look with multiple indexing:

data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  matrix[N, M_1] Z_1;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  real<lower=0> sigma;  // residual SD
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  // cholesky factor of correlation matrix
  cholesky_factor_corr[M_1] L_1;
}
transformed parameters {
  // actual group-level effects
  matrix[N_1, M_1] r_1 = (diag_pre_multiply(sd_1, L_1) * z_1)';
}
model {
  // initialize linear predictor term
  vector[N] mu = Intercept + Xc * b + rows_dot_product(Z_1, r_1[J_1]);
  ...
}
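
To check that the two model blocks above compute the same group-level term, here is a minimal sketch in plain Python (not Stan, and not code from brms). The function names and the toy numbers are made up for illustration; the only thing mirrored from the Stan code is the arithmetic, including the shift from Stan's 1-based J_1 to Python's 0-based lists.

```python
def loop_form(r_1, Z_1, J_1):
    """Per-observation loop, mirroring the current Stan code."""
    out = []
    for n in range(len(J_1)):
        j = J_1[n] - 1  # Stan indices are 1-based, Python lists are 0-based
        out.append(sum(r_1[j][m] * Z_1[n][m] for m in range(len(Z_1[n]))))
    return out


def multi_index_form(r_1, Z_1, J_1):
    """Mirrors rows_dot_product(Z_1, r_1[J_1])."""
    # r_1[J_1]: pick one row of r_1 per observation (an N x M_1 matrix)
    r_sel = [r_1[j - 1] for j in J_1]
    # rows_dot_product: dot product of matching rows
    return [sum(z * r for z, r in zip(z_row, r_row))
            for z_row, r_row in zip(Z_1, r_sel)]


# Hypothetical toy data: N_1 = 3 levels, M_1 = 2 coefficients, N = 4 observations
r_1 = [[0.5, -1.0], [1.5, 2.0], [-0.25, 0.75]]
Z_1 = [[1.0, 0.2], [1.0, -0.4], [1.0, 1.3], [1.0, 0.0]]
J_1 = [2, 1, 3, 2]

a = loop_form(r_1, Z_1, J_1)
b = multi_index_form(r_1, Z_1, J_1)
assert all(abs(x - y) < 1e-12 for x, y in zip(a, b))
```

This only confirms that the two forms agree numerically; it says nothing about which one samples faster in Stan.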

SteveBronder commented 3 years ago

Just wanted to update: for 2.27 there are a few notable things related to this issue that could make brms a good bit faster. I think we can try these once rstan updates to 2.27.

  1. https://github.com/stan-dev/stanc3/pull/865 makes it so that data and transformed data are stored as an Eigen::Map<Eigen::Matrix>. The PR goes over what that means in detail, but the tl;dr is that we should do large data manipulations once in the transformed data block, and then we won't need to copy data when calling Stan Math functions in the model or transformed parameters blocks. This only applies to data, not to parameters.

  2. In https://github.com/stan-dev/math/pull/2462 I made a specialization for csr_matrix_times_vector(data, parameters), so sparse matrix operations may become more efficient.

  3. We have fma() specializations for matrices and vectors now, so

  // add more terms to the linear predictor
  mu += fma(r_1_1[J_1], Z_1_1, r_1_2[J_1] .* Z_1_2);

should be faster than

  for (n in 1:N) {
    // add more terms to the linear predictor
    mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n];
  }

though we have an optimization in the compiler to take

  // add more terms to the linear predictor
  mu += r_1_1[J_1] .* Z_1_1 + r_1_2[J_1] .* Z_1_2;

and automagically do the fma thing if it can. It's not exposed yet but should be by 2.28, so it might be worth just waiting.
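
For reference, fma(x, y, z) returns x * y + z, applied elementwise for vector arguments, so the three snippets above describe the same update to mu. A minimal plain-Python sketch of that equivalence follows (not Stan; the helper names take, elt_mul and fma_vec and the toy numbers are hypothetical, and no fused rounding is emulated):

```python
def take(v, idx):
    """Multiple indexing v[idx] with 1-based indices, as in Stan."""
    return [v[i - 1] for i in idx]


def elt_mul(x, y):
    """Elementwise product, like Stan's .* operator."""
    return [a * b for a, b in zip(x, y)]


def fma_vec(x, y, z):
    """Elementwise x * y + z; mirrors only the arithmetic of fma()."""
    return [a * b + c for a, b, c in zip(x, y, z)]


# Hypothetical toy data: 3 grouping levels, 4 observations
r_1_1 = [0.5, 1.5, -0.25]
r_1_2 = [-1.0, 2.0, 0.75]
Z_1_1 = [1.0, 1.0, 1.0, 1.0]
Z_1_2 = [0.2, -0.4, 1.3, 0.0]
J_1 = [2, 1, 3, 2]

# Loop form: one observation at a time
looped = [r_1_1[J_1[n] - 1] * Z_1_1[n] + r_1_2[J_1[n] - 1] * Z_1_2[n]
          for n in range(len(J_1))]

# fma form: fma(r_1_1[J_1], Z_1_1, r_1_2[J_1] .* Z_1_2)
fused = fma_vec(take(r_1_1, J_1), Z_1_1, elt_mul(take(r_1_2, J_1), Z_1_2))

assert all(abs(x - y) < 1e-12 for x, y in zip(looped, fused))
```

The sketch checks the arithmetic only; any speed difference would have to be measured in Stan itself.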

paul-buerkner commented 3 years ago

That sounds really nice and exciting! Out of interest, did you run any performance benchmarks for an actual sparse situation? I wasn't sure whether the benchmarks shown in stan-dev/math#2462 are for sparse matrix operations in particular.

SteveBronder commented 3 years ago

I did not do any benchmarking, though I can whip one up this week using brms; that should be pretty easy. My gut feeling is that it should be faster than calling the multi-indexing each time, but there will probably be some sort of cost matrix over the size of groups and data that I need to think about.

wds15 commented 3 years ago

These things can be really weird. I personally stopped trusting my intuition and now rely on brute-force benchmarks only.

It would be really nice to leave for loops behind...

paul-buerkner commented 7 months ago

With rstan now being more up to date, I have moved this issue higher on the agenda.