ajnafa / IPWBayes


Multilevel Generalization for Non-Centered Parameterization #1

Open ajnafa opened 1 year ago

ajnafa commented 1 year ago

In theory, there should be a straightforward generalization to multilevel data structures that yields results similar to generalized estimating equations after integrating over the random intercepts.

However, the example of double-weighting provided in Savitsky and Williams (2022) is done in the context of a centered parameterization (see appendix C of their manuscript), and I am admittedly not entirely sure how to code up an implementation that would be compatible with brms, which only supports the non-centered parameterization for group-level effects.

I'll give this some more thought tomorrow but we will probably need to ask the Stan folks about this one. Opening an issue here to keep it on my todo list.
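
For concreteness, the non-centered parameterization builds the group intercepts from a standard-normal offset scaled by the group-level standard deviation, rather than sampling them from normal(0, tau) directly. A minimal sketch (placeholder names G, tau, z, and u, not brms's own):

data {
  int<lower = 1> G; // number of groups
}
parameters {
  real<lower = 0> tau; // group-level standard deviation
  vector[G] z;         // standardized group-level offsets
}
transformed parameters {
  vector[G] u = tau * z; // group intercepts recovered by rescaling
}
model {
  z ~ std_normal();
  tau ~ exponential(1);
  // the centered version would instead declare vector[G] u in parameters
  // and sample u ~ normal(0, tau) directly
}

The open question is where the design-stage weights for the groups should enter once the intercepts are constructed this way rather than drawn directly as in Savitsky and Williams' appendix C.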

sjwild commented 1 year ago

I think this should work. I modified the Stan script from the repo of your paper with Andrew. Next step for me is to generate some brms code and use brms variable names. Tomorrow's project.

I think a function to build the REs works better here, and is probably easier to use with a brms custom family.

I have not tested this. It may explode. Stan didn't yell at me when I compiled it. I should simulate data and test it. Also a project for tomorrow.
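
For the renaming step, my recollection is that brms builds a varying intercept from pieces named roughly like the following (a sketch from memory, so treat N_1, M_1, J_1, z_1, sd_1, and r_1_1 as assumptions that may differ across brms versions):

data {
  int<lower = 1> N;            // observations
  int<lower = 1> N_1;          // number of grouping levels
  int<lower = 1> M_1;          // coefficients per level (1 for a varying intercept)
  array[N] int<lower = 1> J_1; // grouping indicator per observation
}
parameters {
  vector<lower = 0>[M_1] sd_1; // group-level standard deviations
  array[M_1] vector[N_1] z_1;  // standardized group-level effects
}
transformed parameters {
  vector[N_1] r_1_1 = sd_1[1] * z_1[1]; // actual group-level effects
}
model {
  // in the likelihood, brms adds r_1_1[J_1[n]] to mu[n]
  target += std_normal_lpdf(z_1[1]);
}

If the script below is translated into those names (tau -> sd_1, tau_z -> z_1, U -> r_1_1, idx_G -> J_1), the custom-family hookup should be mostly a renaming exercise.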

/* Pseudo Bayesian Inverse Probability of Treatment Weighted Estimator
* Author: A. Jordan Nafa; Stan Version 2.30.1; Last Revised 08-29-2022 */

  // To solve this problem, we need a function to compute the REs in the transformed parameters block.
  // It is not difficult at all, and I have overthought it. I am annoyed with myself.
  // We need a function like the following:
  // real compute_res(vector weights_re, vector tau_z, real tau, int G) {
  //    real check_term = 0.0;
  //    for(g in 1:G){
  //       check_term = check_term + weights_re[g] * tau_z[g] * tau;
  //    }
  //    return check_term;
  //  }
  // We can drop the check term, because an lpdf is expected to add a single value to the log posterior;
  // instead we want a vector of random effects that will be added to the log posterior at a later step.

  // New data
  // array[N] int<lower = 1> idx_G; // index specifying group
  // vector<lower = 0>[G] ipw_re_mu; // Mean of the Group-Level Weights
  // vector<lower = 0>[G] ipw_re_sigma; // Scale of the Group-Level Weights
  // possibly priors for tau or for the weights.

  // New parameters
  // real tau;
  // vector[G] tau_z; 

  // new transformed parameters
  // vector[G] U;

  // new priors
  // tau_z ~ std_normal()
  // tau ~ exponential(1)

functions {
  // Weighted Log PDF of the Gaussian Pseudo-Likelihood
  real normal_ipw_lpdf(vector y, vector mu, real sigma, vector w_tilde, int N) {
    real weighted_term;
    weighted_term = 0.00;
    for (n in 1:N) {
      weighted_term = weighted_term + w_tilde[n] * (normal_lpdf(y[n] | mu[n], sigma));
    }
    return weighted_term;
  }

  // new function to compute REs. For use hopefully with brms custom family
  vector compute_res(vector weights_re, vector tau_z, real tau, int G) {
    vector[G] U;
    for (g in 1:G) {
      U[g] = weights_re[g] * tau_z[g] * tau;
    }
    return U;
  }

}

data {
  int<lower = 0> N; // Observations
  vector[N] Y; // Outcome Stage Response
  int<lower = 1> K; // Number of Population-Level Effects and Intercept
  matrix[N, K] X; // Design Matrix for the Population-Level Effects

  // Statistics from the Design Stage Model
  vector<lower = 0>[N] ipw_mu; // Mean of the Population-Level Weights
  vector<lower = 0>[N] ipw_sigma; // Scale of the Population-Level Weights

  // Prior on the scale of the weights
  real<lower = 0> sd_prior_shape1;
  real<lower = 0> sd_prior_shape2;
  real<lower = 0> re_prior_shape1;
  real<lower = 0> re_prior_shape2;

  // Grouping variables
  int<lower = 1> G; // Number of groups
  array[N] int<lower = 1> idx_G; // index of groups
  vector<lower = 0>[G] ipw_re_mu; // Mean of the Group-Level Weights
  vector<lower = 0>[G] ipw_re_sigma; // Scale of the Group-Level Weights

}

transformed data {
  // Priors on the coefficients and intercept
  real<lower = 0> b_prior_sd = 1.5 * (sd(Y)/sd(X[, 2]));
  real alpha_prior_mu = mean(Y);
  real<lower = 0> alpha_prior_sd = 2 * sd(Y);
  real<lower = 0> sigma_prior = 1/sd(Y);

  int Kc = K - 1; 
  matrix[N, Kc] Xc;  // Centered version of X without an Intercept
  vector[Kc] means_X;  // Column Means of the Uncentered Design Matrix

  // Centering the design matrix
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}

parameters {
  vector[Kc] b; // Population-Level Effects
  real Intercept; // Population-Level Intercept for the Centered Predictors
  real<lower=0> sigma;  // Dispersion Parameter
  vector<lower=0, upper=1>[N] weights_z; // Parameters for the IPT Weights

  // New for random effects
  real<lower = 0> tau; // sd for REs
  vector[G] tau_z; // for non-centered res
  vector<lower=0, upper=1>[G] weights_z_re; // not sure if these weights are also bound between 0 and 1

}

transformed parameters {
  // Compute the IPT Weights
  vector[N] w_tilde; // IPT Weights
  vector[G] U; // random effects
  vector[G] weights_re; // Group-level weights
  w_tilde = ipw_mu + ipw_sigma .* weights_z;
  weights_re = ipw_re_mu + ipw_re_sigma .* weights_z_re;

  U = compute_res(weights_re, tau_z, tau, G);

}

model {
  // Initialize the Linear Predictor
  vector[N] mu = Intercept + Xc * b + U[idx_G];

  // Sampling the Weights
  weights_z ~ beta(sd_prior_shape1, sd_prior_shape2);

  // Priors for the Model Parameters
  target += normal_lpdf(Intercept | alpha_prior_mu, alpha_prior_sd);
  target += normal_lpdf(b | 0, b_prior_sd);
  target += exponential_lpdf(sigma | sigma_prior);

  // priors for random effect
  target += std_normal_lpdf(tau_z);
  target += exponential_lpdf(tau | 1);
  weights_z_re ~ beta(re_prior_shape1, re_prior_shape2);

  // Weighted Likelihood
  target += normal_ipw_lpdf(Y | mu, sigma, w_tilde, N);
}

generated quantities {
  // Population-level Intercept on the Original Scale
  real b_Intercept = Intercept - dot_product(means_X, b);
}
ajnafa commented 1 year ago

This looks great @sjwild! I'll see if I can't integrate this into the build functions sometime this week and do some further testing.