paul-buerkner / brms

brms R package for Bayesian generalized multivariate non-linear multilevel models using Stan
https://paul-buerkner.github.io/brms/
GNU General Public License v2.0

Matrix normal sampling to avoid direct computation of Kronecker products #1185

Closed Jordan-Scott-Martin closed 7 months ago

Jordan-Scott-Martin commented 3 years ago

Hello Paul,

Currently, when implementing correlated random effects with a structured covariance, e.g. in a bivariate phylogenetic model

  m1 = bf(trait1 ~ 1 + (1|G| gr(phylo, cov = A)))
  m2 = bf(trait2 ~ 1 + (1|G| gr(phylo, cov = A)))

brms directly computes the Kronecker product $A \otimes G$. Recently, you implemented a slightly more efficient computation that avoids products with very small off-diagonal elements (#977), but the computation time for dense matrices without many zeros remains quite long.

However, a matrix normal parameterization can instead be used to estimate the scaled matrices without directly computing the Kronecker product. In particular, an n x p matrix P of random effects where

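$$\operatorname{vec}(P) \sim \mathcal{N}\big(\operatorname{vec}(0),\, G \otimes A\big) \iff P \sim \mathcal{MN}_{n \times p}(0,\, A,\, G)$$

(here vec stacks the columns of $P$; stacking its rows instead gives the covariance $A \otimes G$)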

can be sampled from the matrix normal distribution such that

$$P = 0 + L_A Z_P L_G^\top$$

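where $L_A$ is the lower-triangular Cholesky factor of $A$, $Z_P$ is an $n \times p$ matrix of independent standard normals, and $L_G^\top$ is the upper-triangular Cholesky factor of $G$ (the transpose of its lower-triangular factor $L_G$), as explained on the Wikipedia page for the matrix normal distribution.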
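A quick numerical check of this equivalence, as a pure-Python sketch with no numerical libraries. The matrices `A` and `G` and all helper names are hypothetical, and vec stacks columns. It relies on the identity $\operatorname{vec}(XYW) = (W^\top \otimes X)\operatorname{vec}(Y)$, so $\operatorname{vec}(L_A Z_P L_G^\top) = (L_G \otimes L_A)\operatorname{vec}(Z_P)$ and the covariance of $\operatorname{vec}(P)$ is $(L_G \otimes L_A)(L_G \otimes L_A)^\top = G \otimes A$. The matrix-product form only needs the small factors $L_A$ and $L_G$; the Kronecker product is built here purely for comparison.

```python
import math
import random


def matmul(X, Y):
    """Plain matrix product of two lists of rows."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def kron(X, Y):
    """Kronecker product X (x) Y."""
    return [[X[i][j] * Y[k][l] for j in range(len(X[0])) for l in range(len(Y[0]))]
            for i in range(len(X)) for k in range(len(Y))]


def vec(X):
    """Stack the columns of X into one list (column-major order)."""
    return [X[i][j] for j in range(len(X[0])) for i in range(len(X))]


def cholesky(S):
    """Lower-triangular Cholesky factor of a symmetric positive-definite matrix."""
    n = len(S)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(S[i][i] - s) if i == j else (S[i][j] - s) / L[j][j]
    return L


def max_abs_diff(xs, ys):
    return max(abs(x - y) for x, y in zip(xs, ys))


random.seed(1)
A = [[1.0, 0.6, 0.3], [0.6, 1.0, 0.4], [0.3, 0.4, 1.0]]  # hypothetical n = 3 correlation matrix
G = [[1.0, 0.5], [0.5, 2.0]]                              # hypothetical p = 2 trait covariance
L_A, L_G = cholesky(A), cholesky(G)
Z = [[random.gauss(0.0, 1.0) for _ in range(2)] for _ in range(3)]  # n x p standard normals

# Matrix normal draw: P = L_A Z L_G^T
P = matmul(matmul(L_A, Z), transpose(L_G))

# The same draw through the Kronecker product: vec(P) = (L_G (x) L_A) vec(Z)
K = kron(L_G, L_A)
vec_P_kron = [row[0] for row in matmul(K, [[v] for v in vec(Z)])]
err_draw = max_abs_diff(vec(P), vec_P_kron)

# Implied covariance of vec(P): K K^T should equal G (x) A
err_cov = max_abs_diff(vec(matmul(K, transpose(K))), vec(kron(G, A)))
print(err_draw, err_cov)  # both are at floating-point rounding level
```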

This matrix normal parameterization seems to provide much more efficient estimation in Stan:

  data {
    int<lower=1> n;    // number of species
    int<lower=1> p;    // number of traits
    matrix[n, n] A;    // known phylogenetic correlation matrix
  }
  transformed data {
    matrix[n, n] L_A = cholesky_decompose(A);  // lower Cholesky factor of the scaling matrix
  }
  parameters {
    matrix[n, p] z_phylo;               // standard normal deviates
    vector<lower=0>[p] sd_phylo;        // standard deviations
    cholesky_factor_corr[p] cor_phylo;  // Cholesky factor of the correlations
  }
  transformed parameters {
    // matrix normal parameterization of the Kronecker product between G and A
    matrix[n, p] u_phylo = L_A * z_phylo * diag_pre_multiply(sd_phylo, cor_phylo)';  // scaled random effects
  }
  model {
    to_vector(z_phylo) ~ std_normal();  // implies u_phylo ~ MatrixNormal(0, A, G)
  }
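In this Stan code, `diag_pre_multiply(sd_phylo, cor_phylo)` is $\operatorname{diag}(\sigma) L_R$, with $L_R$ the Cholesky factor of the trait correlation matrix $R$. That product is lower triangular with a positive diagonal and satisfies $(\operatorname{diag}(\sigma) L_R)(\operatorname{diag}(\sigma) L_R)^\top = \operatorname{diag}(\sigma) R \operatorname{diag}(\sigma) = G$, so it is exactly the Cholesky factor $L_G$ used above. Here is a small pure-Python check. The values of `sd` and `R` are hypothetical, and `diag_pre_multiply` is a hand-written stand-in for the Stan function.

```python
import math


def cholesky(S):
    """Lower-triangular Cholesky factor of a symmetric positive-definite matrix."""
    n = len(S)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(S[i][i] - s) if i == j else (S[i][j] - s) / L[j][j]
    return L


def diag_pre_multiply(v, X):
    """Python stand-in for Stan's diag_pre_multiply: diag(v) * X."""
    return [[v[i] * X[i][j] for j in range(len(X[0]))] for i in range(len(X))]


sd = [0.8, 1.5]               # hypothetical trait standard deviations (sd_phylo)
R = [[1.0, 0.5], [0.5, 1.0]]  # hypothetical trait correlation matrix
L_R = cholesky(R)             # plays the role of cor_phylo

# Trait covariance matrix G = diag(sd) R diag(sd)
G = [[sd[i] * R[i][j] * sd[j] for j in range(2)] for i in range(2)]

LC = diag_pre_multiply(sd, L_R)  # what the Stan code multiplies in (transposed)
L_G = cholesky(G)                # Cholesky factor of G computed directly
err = max(abs(LC[i][j] - L_G[i][j]) for i in range(2) for j in range(2))
print(LC)  # approximately [[0.8, 0.0], [0.75, 1.299]]
```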

Would it be possible to implement this in brms? Thanks for your time and consideration.

paul-buerkner commented 3 years ago

I am happy to implement that, thank you! Just to make sure it works as intended, could you demonstrate the equivalence for an example, perhaps with a simple model from the https://paul-buerkner.github.io/brms/articles/brms_phylogenetics.html vignette?

Jordan-Scott-Martin commented 3 years ago

Great, thanks for the quick response! I didn't see a phylogenetic correlation in the vignette for straightforward comparison, so I've attached a simple simulation showing accurate recovery of the random effect variances and correlation (matrix-normal-example.pdf). With n = 300, a bivariate Gaussian model without fixed effects runs in ~ 1-2 min on my Windows 10 laptop. Happy to write up another example with the phylo vignette data if I missed something obvious.

paul-buerkner commented 3 years ago

The simple one is fine, thank you!

paul-buerkner commented 3 years ago

What I meant was a comparison between the current brms version and the change you proposed to demonstrate equivalence empirically and show differences in estimation efficiency. Is that already in the example? I didn't see it at least, but perhaps I did overlook it.

Jordan-Scott-Martin commented 3 years ago

Ah of course, good suggestion. Here are the results (matrix-normal-brms-comparison.pdf) of 200 simulated datasets (n = 100) with a bivariate phylogenetic correlation (r = 0.5) estimated using current development brms (2.15.9) code and updated brms code with the matrix normal parameterization. The sampling distributions largely overlap, though the current brms code tends slightly more toward downwardly biased estimates of the sample correlation. Already at n = 100, the matrix normal model finishes in approximately half the time of the current code.

Code for the simulation is on pg 2. The only difference in the models is a change from

  transformed parameters {
    matrix[N_1, M_1] r_1;  // actual group-level effects
    // using vectors speeds up indexing in loops
    vector[N_1] r_1_t1_1;
    vector[N_1] r_1_t2_2;
    // compute actual group-level effects
    r_1 = scale_r_cor_cov(z_1, sd_1, L_1, Lcov_1);
    r_1_t1_1 = r_1[, 1];
    r_1_t2_2 = r_1[, 2];
  }

to

  transformed parameters {
    matrix[N_1, M_1] r_1;  // actual group-level effects
    // using vectors speeds up indexing in loops
    vector[N_1] r_1_t1_1;
    vector[N_1] r_1_t2_2;
    // compute actual group-level effects
    r_1 = Lcov_1 * z_1' * diag_pre_multiply(sd_1, L_1)';
    r_1_t1_1 = r_1[, 1];
    r_1_t2_2 = r_1[, 2];
  }

where z_1' transposes the M_1 x N_1 matrix of standard normals z_1 into the N_1 x M_1 shape required by the matrix normal form.
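The two blocks differ only in how `r_1` is computed. The pure-Python sketch below is a sanity check that compares the Kronecker form with `Lcov_1 * z_1' * diag_pre_multiply(sd_1, L_1)'`. Because `scale_r_cor_cov` is not shown in this thread, the sketch assumes it works like the `scale_r_cor_by_cov` loop posted further down. That means it multiplies $L_{cov} \otimes LC$ by the column-major `to_vector(z_1)` and refills the result row by row into an $N_1 \times M_1$ matrix. `LC` stands in for `diag_pre_multiply(sd_1, L_1)`; sizes and values are hypothetical.

```python
import random


def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def kron(X, Y):
    """Kronecker product X (x) Y."""
    return [[X[i][j] * Y[k][l] for j in range(len(X[0])) for l in range(len(Y[0]))]
            for i in range(len(X)) for k in range(len(Y))]


random.seed(2)
M, N = 2, 3  # M_1 effects, N_1 grouping levels (hypothetical sizes)
Lcov = [[1.0, 0.0, 0.0], [0.6, 0.8, 0.0], [0.3, 0.275, 0.9]]  # lower triangular
LC = [[0.8, 0.0], [0.75, 1.3]]  # stands in for diag_pre_multiply(sd_1, L_1)
z = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(M)]  # z_1 is M_1 x N_1

# Kronecker form: (Lcov (x) LC) * to_vector(z), with to_vector stacking columns
z_flat = [z[i][j] for j in range(N) for i in range(M)]
r_flat = [sum(a * b for a, b in zip(row, z_flat)) for row in kron(Lcov, LC)]
r_kron = [[r_flat[i * M + j] for j in range(M)] for i in range(N)]  # refill by row

# Matrix normal form: Lcov_1 * z_1' * LC'
r_mat = matmul(matmul(Lcov, transpose(z)), transpose(LC))

err = max(abs(r_kron[i][j] - r_mat[i][j]) for i in range(N) for j in range(M))
print(err)  # rounding-level difference
```

Under that assumption, both blocks define the same deterministic transformation of z_1.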

paul-buerkner commented 3 years ago

Thanks! I will make that change.

paul-buerkner commented 3 years ago

Do you have a suggestion for how to adjust the corresponding by version where group levels may have different random effects correlation matrices?

 /* compute correlated group-level effects with 'by' variables
  * in the presence of a within-group covariance matrix
  * Args: 
  *   z: matrix of unscaled group-level effects
  *   SD: matrix of standard deviation parameters
  *   L: an array of cholesky factor correlation matrices
  *   Jby: index which grouping level belongs to which by level
  *   Lcov: cholesky factor of within-group correlation matrix
  * Returns: 
  *   matrix of scaled group-level effects
  */ 
  matrix scale_r_cor_by_cov(matrix z, matrix SD, matrix[] L, 
                            int[] Jby, matrix Lcov) {
    vector[num_elements(z)] z_flat = to_vector(z);
    vector[num_elements(z)] r = rep_vector(0, num_elements(z));
    matrix[rows(L[1]), cols(L[1])] LC[size(L)];
    int rows_z = rows(z);
    int rows_L = rows(L[1]);
    for (i in 1:size(LC)) {
      LC[i] = diag_pre_multiply(SD[, i], L[i]);
    }
    // kronecker product of cholesky factors times a vector
    for (icov in 1:rows(Lcov)) {
      for (jcov in 1:icov) {
        if (Lcov[icov, jcov] > 1e-10) { 
          // avoid calculating products between unrelated individuals
          for (i in 1:rows_L) {
            for (j in 1:i) {
              // incremented element of the output vector
              int k = (rows_L * (icov - 1)) + i;
              // applied element of the input vector
              int l = (rows_L * (jcov - 1)) + j;
              // column number of z to which z_flat[l] belongs
              int m = (l - 1) / rows_z + 1;
              r[k] = r[k] + Lcov[icov, jcov] * LC[Jby[m]][i, j] * z_flat[l];
            }
          }
        }
      }
    }
    // r is returned in another dimension order than z
    return to_matrix(r, cols(z), rows(z), 0);
  }

This was an immediate generalization of the scale_r_cor_cov function, which is now going to be changed.
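The following 0-based pure-Python port of the loop above makes the generalization concrete. It is an illustration, not brms code: sizes and values are hypothetical, and `diag_pre_multiply` is a hand-written stand-in for the Stan function. When every by-group has the same `SD` column and `L`, the loop should reduce to the single-matrix form $L_{cov} z^\top LC^\top$. Note that each term uses `LC[Jby[m]]`, where `m` is the column of `z` being read (the input level `jcov`), not the output level `icov`.

```python
import random


def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def diag_pre_multiply(v, X):
    """Python stand-in for Stan's diag_pre_multiply: diag(v) * X."""
    return [[v[i] * X[i][j] for j in range(len(X[0]))] for i in range(len(X))]


def scale_r_cor_by_cov(z, SD, L, Jby, Lcov):
    """0-based port of the Stan loop (illustration only, not brms code)."""
    rows_z, cols_z = len(z), len(z[0])
    z_flat = [z[i][j] for j in range(cols_z) for i in range(rows_z)]  # to_vector(z)
    r = [0.0] * (rows_z * cols_z)
    # LC[g] = diag_pre_multiply(SD[, g], L[g])
    LC = [diag_pre_multiply([row[g] for row in SD], L[g]) for g in range(len(L))]
    rows_L = len(L[0])
    for icov in range(len(Lcov)):
        for jcov in range(icov + 1):
            if Lcov[icov][jcov] > 1e-10:  # same test as the Stan code
                for i in range(rows_L):
                    for j in range(i + 1):
                        k = rows_L * icov + i  # incremented output element
                        l = rows_L * jcov + j  # applied input element
                        m = l // rows_z        # column of z holding z_flat[l]
                        r[k] += Lcov[icov][jcov] * LC[Jby[m]][i][j] * z_flat[l]
    # to_matrix(r, cols(z), rows(z), 0): refill row by row
    return [[r[a * rows_z + b] for b in range(rows_z)] for a in range(cols_z)]


random.seed(3)
M, N = 2, 4                       # effects per level, grouping levels (hypothetical)
Jby = [0, 0, 1, 1]                # by-group of each grouping level (0-based)
Lcov = [[1.0, 0.0, 0.0, 0.0],
        [0.5, 0.9, 0.0, 0.0],
        [0.3, 0.2, 0.9, 0.0],
        [0.2, 0.1, 0.3, 0.9]]     # lower triangular with positive entries
L_R = [[1.0, 0.0], [0.5, 0.866]]  # approximate Cholesky factor of a correlation matrix
z = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(M)]

# Both by-groups share the same SD column and L, so the loop should match
# the single-matrix form Lcov * z' * LC'
SD = [[0.8, 0.8], [1.5, 1.5]]     # M x (number of by-groups)
r_loop = scale_r_cor_by_cov(z, SD, [L_R, L_R], Jby, Lcov)
LC = diag_pre_multiply([0.8, 1.5], L_R)
r_mat = matmul(matmul(Lcov, transpose(z)), transpose(LC))
err_shared = max(abs(r_loop[a][b] - r_mat[a][b]) for a in range(N) for b in range(M))
print(err_shared)  # rounding-level difference
```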

Jordan-Scott-Martin commented 3 years ago

Apologies for my delayed response, Paul. A straightforward solution for the scale_r_cor_by_cov function is to calculate one scaled matrix per by group and then build the returned r matrix by taking, for each grouping level, the corresponding row from the matrix of its own group. Here is my updated function:

  functions {
    // ...
    matrix scale_r_cor_by_cov(matrix z, matrix SD, matrix[] L,
                              int[] Jby, matrix Lcov) {
      matrix[rows(L[1]), cols(L[1])] LC[size(L)];
      matrix[cols(z), rows(z)] r_by[size(LC)];
      matrix[cols(z), rows(z)] r;
      for (i in 1:size(LC)) {
        // create VCVs for each group
        LC[i] = diag_pre_multiply(SD[, i], L[i]);
        // create scaled effects for each group VCV
        r_by[i] = Lcov * z' * LC[i]';
      }
      for (j in 1:cols(z)) {
        // select appropriate effects based on group
        r[j] = r_by[Jby[j]][j];
      }
      return r;
    }
  }

The updated function continues to run about 2x faster at a small sample size (2 groups, N = 50/group). Results for a single random dataset are shown in the attached file with the full code (matrix-normal-brms-by-function.pdf). I can also run the simulation for more iterations if desirable, though I'd assume the same sampling properties will be observed with and without group-specific VCVs.

paul-buerkner commented 3 years ago

Thank you so much! The posterior medians look quite different though. Is this really still the same model just with a different (more efficient) parameterization?

Jordan-Scott-Martin commented 3 years ago

Apologies for the slow response, this took me a while to figure out! My previous suggestion for updating the scale_r_cor_by_cov function was biased because it scaled the lower Cholesky factors of the group-specific VCVs with a matrix that includes all groups, e.g. $L_{cov} \otimes LC[1]$ and $L_{cov} \otimes LC[2]$ with the known covariance matrix Lcov across all groups, which seems to have attenuated the group-specific VCVs LC[1] and LC[2] toward one another.

I've avoided this issue in the updated function below by first using the matrix normal parameterization to scale the independent random effects in z with the known covariance matrix Lcov, such that $r_{by} = L_{cov} z^\top I$, i.e. $r_{by} \sim \mathcal{MN}(0,\, L_{cov} L_{cov}^\top,\, I)$, where $I$ is an M_1 x M_1 identity matrix. I then scale each row of this N_1 x M_1 matrix with the appropriate group-specific VCV in LC, and the random effects are returned together in the proper order of the r matrix.

  functions {
    // ...
    matrix scale_r_cor_by_cov(matrix z, matrix SD, matrix[] L,
                              int[] Jby, matrix Lcov) {
      matrix[rows(L[1]), cols(L[1])] LC[size(L)];
      matrix[cols(z), rows(z)] r_by[size(LC)];
      matrix[cols(z), rows(z)] r;
      for (i in 1:size(LC)) {
        // create VCVs for each group
        LC[i] = diag_pre_multiply(SD[, i], L[i]);
        // scale effects with known covariance matrix
        r_by[i] = Lcov * z' * diag_matrix(rep_vector(1, rows(z)));
      }
      for (j in 1:cols(z)) {
        // scale effects with group-specific VCV; right-multiplying by LC[g]
        // (not LC[g]') gives row j a covariance proportional to LC[g]' * LC[g]
        r[j] = r_by[Jby[j]][j] * LC[Jby[j]];
      }
      return r;
    }
  }
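The pure-Python sketch below checks what covariance this row-wise construction implies. Helper names, sizes and values are hypothetical. It assumes each row is right-multiplied by the transposed factor `LC[g]'`, as in the single-group form `Lcov * z' * LC'`. That choice gives rows $j$ and $k$ the cross-covariance $A_{jk}\, LC_{g(j)} LC_{g(k)}^\top$, so each row carries its own group's VCV. Multiplying by `LC[g]` without the transpose would instead put $A_{jj}\, LC_g^\top LC_g$ on the diagonal. That matrix differs from $G_g = LC_g LC_g^\top$ whenever any correlation is non-zero. Because r is linear in z, the exact covariance follows from applying the function to each unit matrix.

```python
def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def scale_rows_by_group(z, LC, Jby, Lcov):
    """Hypothetical sketch: r_by = Lcov * z', then row j times LC[Jby[j]]'."""
    r_by = matmul(Lcov, transpose(z))  # N x M (the identity factor is dropped)
    return [matmul([r_by[j]], transpose(LC[Jby[j]]))[0] for j in range(len(r_by))]


M, N = 2, 3
Jby = [0, 1, 1]  # by-group of each grouping level (0-based, hypothetical)
Lcov = [[1.0, 0.0, 0.0], [0.6, 0.8, 0.0], [0.3, 0.275, 0.9]]  # lower triangular
A = matmul(Lcov, transpose(Lcov))                              # implied A
LC = [[[0.8, 0.0], [0.75, 1.3]],   # group 0: positive correlation
      [[1.0, 0.0], [-0.6, 0.8]]]   # group 1: negative correlation

# r is linear in z: applying it to each unit matrix gives the columns of the
# map from the entries of z to the rows of r stacked one after another
columns = []
for n in range(N):
    for m in range(M):
        unit = [[1.0 if (i, j) == (m, n) else 0.0 for j in range(N)] for i in range(M)]
        r = scale_rows_by_group(unit, LC, Jby, Lcov)
        columns.append([x for row in r for x in row])
T = transpose(columns)
cov = matmul(T, transpose(T))  # exact covariance of r when z has iid N(0, 1) entries

# Block (j, k) of cov should equal A[j][k] * LC[g_j] * LC[g_k]'
err = 0.0
for j in range(N):
    for k in range(N):
        target = matmul(LC[Jby[j]], transpose(LC[Jby[k]]))
        for a in range(M):
            for b in range(M):
                err = max(err, abs(cov[j * M + a][k * M + b] - A[j][k] * target[a][b]))
print(err)  # rounding-level difference
```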

As before, I've run a short simulation (200 datasets, 2 groups with N = 50 each) to compare the time and estimation bias of the current brms implementation and the proposed matrix normal for estimating two group-specific phylogenetic correlations (r = -0.5 and 0.5 respectively). As expected, the sampling distributions largely overlap and the matrix normal completes in about half the time.

[Figure: simulation results]

Here is R code for replicating the simulation. Perhaps, given that estimates may differ slightly at the sample level, it would be useful to consider the matrix normal as an optional, alternative implementation for efficiency gains, similar to decomp = "QR"?

paul-buerkner commented 3 years ago

Thank you! This looks much better. What I still don't understand is whether this mathematically implies the same model. I was under the impression that it does, but your last comment about perhaps making it optional confuses me a bit. What makes you suspect that the models are not the same and that the differences are not only due to sampling inaccuracy?

Jordan-Scott-Martin commented 3 years ago

My understanding is that the implementation is mathematically equivalent, following the steps in my first post based on the details of the Wikipedia page. I've also found in recent work that the matrix normal approach provides unbiased estimates in Stan for quantitative genetic models. However, it's less clear to me why the current and proposed brms implementations aren't computationally equivalent in Stan, at least for individual samples. Perhaps this is vaguely similar to differences in the estimation bias of centered vs. non-centered random effects in Stan, but I'm not really sure, as both models seem to be converging well. So I suggested treating the matrix normal as an optional feature because it may otherwise lead to meaningful differences in point estimates between models estimated with newer and older versions of brms.

I'm happy to run more simulations if you have suggestions for how to best investigate the question further.

paul-buerkner commented 3 years ago

Thank you for your time and effort! Perhaps there is a mistake in the current brms implementation, or we messed up the indexing during the transformation from z to r in one of the implementations? I don't really know, but I am hesitant to implement the new version before I understand where these differences are coming from. Does the same problem already apply without the by stuff? If yes, running and comparing the examples from the brms_phylogenetics vignette could help our understanding. Or did you check them already?

paul-buerkner commented 7 months ago

I will close this issue for now to clean up the issue tracker a bit since this isn't high priority and the discussion seems stale.