stan-dev / stanc3

The Stan transpiler (from Stan to C++ and beyond).
BSD 3-Clause "New" or "Revised" License

zero_sum_vector type #1111

Closed bob-carpenter closed 2 months ago

bob-carpenter commented 2 years ago

Is your feature request related to a problem? Please describe.

Suppose we have a logistic regression with, say, 5 age groups, 6 income levels, and 2 sexes. If I want to write a regression like this:

y[n] ~ bernoulli_logit(alpha + beta[age[n]] + gamma[income[n]] + delta[sex[n]]);

then I run into identifiability issues. There are (at least) two ways to deal with this:

(a) setting beta[1] = 0, gamma[1] = 0, and delta[1] = 0, or (b) setting sum(beta) = 0, sum(gamma) = 0, and sum(delta) = 0.

Approach (a) is more traditional in frequentist applications, but we like (b) more for Bayes because the priors are symmetric rather than defined relative to the first element.

Describe alternatives you've considered

This can be coded in Stan currently as follows.

parameters {
  real alpha;
  vector[4] beta_prefix;
  vector[5] gamma_prefix;
  vector[1] delta_prefix;
}
transformed parameters {
  vector[5] beta = append_row(beta_prefix, -sum(beta_prefix));
  vector[6] gamma = append_row(gamma_prefix, -sum(gamma_prefix));
  vector[2] delta = append_row(delta_prefix, -sum(delta_prefix));
}

Note that delta = [delta[1], -delta[1]]', which is how we'd like to handle binary predictors like this.

Describe the solution you'd like

I'd like to be able to declare the above as

parameters {
  zero_sum_vector[5] beta;
  zero_sum_vector[6] gamma;
  zero_sum_vector[2] delta;
}

and have Stan do the transform under the hood. There is no Jacobian adjustment because the transform is a constant linear function, so the C++ for the transform and its inverse is trivial.
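For intuition, the transform pair can be sketched as ordinary Stan functions (just a sketch of the append_row construction above; the function names are made up, and the built-in version would be implemented in C++):

functions {
  // constrain: map a free (K-1)-vector to a K-vector whose elements sum to zero
  vector zero_sum_constrain(vector y) {
    return append_row(y, -sum(y));
  }
  // unconstrain: drop the redundant last element
  vector zero_sum_unconstrain(vector x) {
    return head(x, rows(x) - 1);
  }
}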

Considerations for priors

Either way we do this (pinning the first value to 0 or setting the last value to the negative sum of the other values), the meaning of the transformed parameters is a bit subtle. The Stan program

parameters {
  zero_sum_vector[2] delta;
}
model {
  delta ~ normal(0, sigma_delta);
}

is equivalent to unfolding the model as

model {
  delta[1] ~ normal(0, sigma_delta);
  delta[2] ~ normal(0, sigma_delta);
}

which, because delta[2] = -delta[1] and the normal is symmetric about its location parameter 0, works out to

parameters {
  real delta1;
}
transformed parameters {
  vector[2] delta = [delta1, -delta1]';
}
model {
  delta1 ~ normal(0, sigma_delta);
  delta1 ~ normal(0, sigma_delta);
}

which works out to be equivalent to using

model {
  delta[1] ~ normal(0, sigma_delta / sqrt(2));
}
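Writing out the product of the two identical normal densities makes the sqrt(2) factor explicit:

$$
\text{normal}(x \mid 0, \sigma)\,\text{normal}(x \mid 0, \sigma)
\propto \exp\!\left(-\frac{x^2}{2\sigma^2}\right)\exp\!\left(-\frac{x^2}{2\sigma^2}\right)
= \exp\!\left(-\frac{x^2}{2(\sigma/\sqrt{2})^2}\right)
\propto \text{normal}\!\left(x \mid 0, \sigma/\sqrt{2}\right).
$$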

If we instead use delta ~ normal(0, sqrt(2) * sigma_delta), then we get the equivalent of delta[1] ~ normal(0, sigma_delta), which means sigma_delta is the scale of half the difference between delta[1] and delta[2]. The same effect can be derived more efficiently by just writing this out directly as

delta[1] ~ normal(0, sigma_delta);

If you want to put a prior on the scale of the difference, that can also be done directly without the need for a Jacobian adjustment as

delta[2] - delta[1] ~ normal(0, sigma_delta);

or equivalently,

delta[1] ~ normal(0, sigma_delta / 2);
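(The factor of 1/2 arises because delta[2] - delta[1] = -2 * delta[1], and a normal(0, sigma_delta) density evaluated at -2 * delta[1] is, up to a constant, a normal(0, sigma_delta / 2) density in delta[1].)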

The problem with the traditional approach of setting alpha[1] = 0 is that we can only put priors on alpha[2:K], whose scale is determined relative to having pinned alpha[1] = 0. That means the priors are implicitly about the differences between alpha[1] and alpha[k].

spinkney commented 2 years ago

Let's get this to work with multiplier and offset too.

As an aside, if you leave off the intercept and declare the full size of the parameter, the interpretation of the parameters is relative to the mean value of the parameter vector.

WardBrian commented 2 years ago

I’ve let #971 fall out of date a little since it didn’t seem like there was a consensus on merging it, but composing offset/multiplier with other transforms is definitely still a possibility
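For concreteness, a composed declaration might eventually look something like the following purely hypothetical syntax (nothing like this is accepted by the compiler today, and the eventual form would depend on how #971 lands):

parameters {
  real<lower=0> sigma_beta;
  // hypothetical: zero-sum constraint composed with a multiplier
  zero_sum_vector<multiplier=sigma_beta>[5] beta;
}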

bob-carpenter commented 2 years ago

Let's get this to work with multiplier and offset too.

I think multipliers would make sense here, but I don't see how offset would help with the sum-to-zero constraint.

spinkney commented 2 years ago

I'm looking at https://mc-stan.org/docs/2_28/stan-users-guide/parameterizing-centered-vectors.html middle of the page:

To generate distributions with marginals other than standard normal, the resulting beta may be scaled by some factor sigma and translated to some new location mu.

bob-carpenter commented 2 years ago

I ran some more experiments to compare the alpha[1] = 0 approach vs. the alpha[N] = -sum(alpha[1:N-1]) approach. In both cases, we put a prior on the sum-to-zero constrained transformed parameters, as would be done if this transform were built in.

Here's Stan code to evaluate the alpha[1] = 0 approach.

parameters {
  vector[1] alpha1;
  vector[3] beta1;
  vector[7] gamma1;
  vector[15] delta1;
}
transformed parameters {
  vector[2] alpha0 = append_row(0, alpha1);
  vector[4] beta0 = append_row(0, beta1);
  vector[8] gamma0 = append_row(0, gamma1);
  vector[16] delta0 = append_row(0, delta1);
}
model {
  alpha0 ~ normal(0, 10);
  beta0 ~ normal(0, 10);
  gamma0 ~ normal(0, 10);
  delta0 ~ normal(0, 10);
}
generated quantities {
  vector[2] alpha = softmax(alpha0);
  vector[4] beta = softmax(beta0);
  vector[8] gamma = softmax(gamma0);
  vector[16] delta = softmax(delta0);
}

Note that the prior is over the transformed parameters (e.g., alpha0), but because the first element is a constant zero, it's equivalent to putting the prior on the raw parameters (e.g., alpha1). Beyond K = 2, you see the asymmetry induced by the prior, with the first element shrinking in magnitude and variance compared to the others.

 variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
alpha[1]  0.50   0.48 0.46 0.71 0.00 1.00 1.00    11070       NA
alpha[2]  0.50   0.52 0.46 0.71 0.00 1.00 1.00    11052       NA

beta[1]   0.13   0.00 0.29 0.00 0.00 0.96 1.00     3541     2829
beta[2]   0.30   0.00 0.42 0.00 0.00 1.00 1.00     6697     2767
beta[4]   0.29   0.00 0.42 0.00 0.00 1.00 1.00     6950     3118

gamma[1]  0.01   0.00 0.07 0.00 0.00 0.01 1.00     2857     3203
gamma[2]  0.14   0.00 0.32 0.00 0.00 1.00 1.00     5835     3053
gamma[8]  0.14   0.00 0.32 0.00 0.00 1.00 1.00     6239     3355

delta[1]  0.00   0.00 0.00 0.00 0.00 0.00 1.00     2824     3220
delta[2]  0.07   0.00 0.23 0.00 0.00 0.82 1.00     6826     2971
delta[16] 0.07   0.00 0.23 0.00 0.00 0.85 1.00     6567     2863

Here's a model to test the sum-to-zero approach. Note again that the prior is on the transformed sum-to-zero parameter here. If you only put a prior on the raw parameters alpha1, etc., you get a bias with the last element growing in magnitude relative to the others.

parameters {
  vector[1] alpha1;
  vector[3] beta1;
  vector[7] gamma1;
  vector[15] delta1;
}
transformed parameters {
  vector[2] alpha0 = append_row(alpha1, -sum(alpha1));
  vector[4] beta0 = append_row(beta1, -sum(beta1));
  vector[8] gamma0 = append_row(gamma1, -sum(gamma1));
  vector[16] delta0 = append_row(delta1, -sum(delta1));
}
model {
  alpha0 ~ normal(0, 10);
  beta0 ~ normal(0, 10);
  gamma0 ~ normal(0, 10);
  delta0 ~ normal(0, 10);
}
generated quantities {
  vector[2] alpha = softmax(alpha0);
  vector[4] beta = softmax(beta0);
  vector[8] gamma = softmax(gamma0);
  vector[16] delta = softmax(delta0);
}

and here's the posterior, which is now symmetric (up to sampling error---if you run a lot longer, the delta values even out).

 variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
 alpha[1]  0.50   0.46 0.47 0.68 0.00 1.00 1.00     7637       NA
 alpha[2]  0.50   0.54 0.47 0.68 0.00 1.00 1.00     7055       NA

 beta[1]   0.25   0.00 0.40 0.00 0.00 1.00 1.00     4345     3099
 beta[2]   0.25   0.00 0.40 0.00 0.00 1.00 1.00     4004     3163
 beta[3]   0.24   0.00 0.40 0.00 0.00 1.00 1.00     5269     2938
 beta[4]   0.26   0.00 0.41 0.00 0.00 1.00 1.00     4374     2671

 gamma[1]  0.12   0.00 0.30 0.00 0.00 1.00 1.00     5036     2965
 gamma[2]  0.13   0.00 0.31 0.00 0.00 0.99 1.00     5467     3234
 gamma[7]  0.12   0.00 0.30 0.00 0.00 1.00 1.00     5039     3208
 gamma[8]  0.13   0.00 0.31 0.00 0.00 1.00 1.00     3649     2988

 delta[1]  0.06   0.00 0.22 0.00 0.00 0.81 1.00     5187     3030
 delta[2]  0.06   0.00 0.21 0.00 0.00 0.59 1.00     5383     3239
 delta[15] 0.06   0.00 0.21 0.00 0.00 0.63 1.00     5929     2901
 delta[16] 0.06   0.00 0.22 0.00 0.00 0.73 1.00     3575     2923

With more draws, the delta[k] posteriors all converge to the same value:

  variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
 delta[1]  0.06   0.00 0.22 0.00 0.00 0.72 1.00    51926    30994
 delta[2]  0.06   0.00 0.22 0.00 0.00 0.73 1.00    50826    31697
 delta[3]  0.06   0.00 0.22 0.00 0.00 0.73 1.00    45405    30223
 delta[4]  0.06   0.00 0.22 0.00 0.00 0.73 1.00    50920    29601
 delta[5]  0.06   0.00 0.22 0.00 0.00 0.73 1.00    46371    31507
 delta[6]  0.06   0.00 0.22 0.00 0.00 0.78 1.00    46401    30598
 delta[7]  0.06   0.00 0.22 0.00 0.00 0.70 1.00    55413    31754
 delta[8]  0.06   0.00 0.22 0.00 0.00 0.74 1.00    49898    32138
 delta[9]  0.06   0.00 0.22 0.00 0.00 0.73 1.00    50115    30080
 delta[10] 0.06   0.00 0.22 0.00 0.00 0.74 1.00    54851    31502
 delta[11] 0.06   0.00 0.22 0.00 0.00 0.74 1.00    45674    31264
 delta[12] 0.06   0.00 0.22 0.00 0.00 0.78 1.00    52809    31116
 delta[13] 0.06   0.00 0.22 0.00 0.00 0.70 1.00    53378    30838
 delta[14] 0.06   0.00 0.22 0.00 0.00 0.74 1.00    47033    31092
 delta[15] 0.06   0.00 0.22 0.00 0.00 0.73 1.00    55972    30316
 delta[16] 0.06   0.00 0.22 0.00 0.00 0.72 1.00    37695    30375

In the sum-to-zero case, if you look at the constrained parameters, you get this, with symmetric values that have slightly lower standard deviation than the normal(0, 10) would imply.

  variable  mean median   sd  mad     q5   q95 rhat ess_bulk ess_tail
 alpha0[1]   0.02   0.04 7.10 7.09 -11.68 11.70 1.00    74212    29813
 alpha0[2]  -0.02  -0.04 7.10 7.09 -11.70 11.68 1.00    74212    29813
 delta0[1]  -0.05  -0.09 9.71 9.78 -16.05 15.91 1.00    70935    30294
 delta0[2]   0.01   0.06 9.65 9.71 -15.94 15.84 1.00    65353    30276
 delta0[3]  -0.07  -0.10 9.72 9.85 -15.98 15.91 1.00    55981    29269
 delta0[4]   0.01   0.04 9.71 9.69 -15.97 16.01 1.00    68359    29221
 delta0[5]   0.07   0.03 9.60 9.51 -15.62 15.88 1.00    55903    29952
 delta0[6]   0.05  -0.01 9.75 9.76 -15.91 16.22 1.00    63920    29850
 delta0[7]   0.04   0.09 9.62 9.68 -15.78 15.78 1.00    72702    30726
 delta0[8]  -0.01  -0.02 9.74 9.62 -15.98 15.98 1.00    63786    29842
 delta0[9]  -0.01   0.01 9.65 9.70 -15.83 15.81 1.00    62672    31029
 delta0[10] -0.03  -0.04 9.71 9.77 -16.05 15.87 1.00    74055    30233
 delta0[11]  0.01   0.02 9.69 9.73 -15.98 15.89 1.00    58518    28982
 delta0[12]  0.03  -0.01 9.80 9.83 -16.09 16.20 1.00    68475    28966
 delta0[13] -0.06  -0.02 9.65 9.63 -15.87 15.82 1.00    70493    30169
 delta0[14] -0.02  -0.09 9.60 9.59 -15.83 15.77 1.00    62308    29435
 delta0[15] -0.01  -0.08 9.68 9.75 -15.83 15.91 1.00    73696    29059
 delta0[16]  0.05   0.07 9.67 9.64 -15.82 15.99 1.00    41921    27950

This is because we have effectively applied K prior statements to only K - 1 free parameters. If you look at the binary case, it's easy to see that the prior is equivalent to

alpha0[1] ~ normal(0, 10);
alpha0[2] ~ normal(0, 10);

and then because alpha0[2] = -alpha0[1] and the normal is symmetric around 0, this reduces to

alpha1[1] ~ normal(0, 10);
alpha1[1] ~ normal(0, 10);

so we get a double prior on alpha1[1], both of which apply, and as a consequence the standard deviation is reduced from 10 to 10 / sqrt(2), or about 7.1, as you see above. The reduction in prior scale then tapers off as K grows.
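For general K, the same counting argument (K normal(0, sigma) statements on only K - 1 free parameters) gives each constrained element a marginal scale of sigma * sqrt(1 - 1 / K). For example, 10 * sqrt(1 / 2) is about 7.1 for K = 2 and 10 * sqrt(15 / 16) is about 9.68 for K = 16, which matches the delta0 standard deviations of roughly 9.7 in the table above.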

bob-carpenter commented 2 years ago

@spinkney: I think we need to remove that translation bit, as I don't think it makes sense here. I hadn't realized someone had worked out the reduction in scale as inv(sqrt(1 - inv(K))) and put it in the user's guide.

spinkney commented 2 years ago

I tested this using the SVD approach, and it appears to work as well.

library(cmdstanr)  # needed for cmdstan_model() below

U <- list()
d <- list()
j <- 0
for(i in c(2, 4, 8, 16)) {
  j <- j + 1
  Sigma <- contr.sum(i) %*% diag(i - 1) %*% t(contr.sum(i))
  s <- svd(Sigma) # Guarantee symmetry
  U[[j]] <- s$u
  d[[j]] <- abs(zapsmall(s$d)) 
}

mod_test <- cmdstan_model("sum_to_zero_test.stan")

stan_data <- list(U_alpha = U[[1]],
                  d_alpha = d[[1]],
                  U_beta = U[[2]],
                  d_beta = d[[2]],
                  U_gamma = U[[3]],
                  d_gamma = d[[3]],
                  U_delta = U[[4]],
                  d_delta = d[[4]]
                  )

test_out <- mod_test$sample(
  data = stan_data,
  parallel_chains = 4
)

summary_out <- test_out$summary(c("alpha", "beta", "gamma", "delta"))

where sum_to_zero_test.stan is

data {
  matrix[2, 2] U_alpha;
  vector[2] d_alpha;
  matrix[4, 4] U_beta;
  vector[4] d_beta;
  matrix[8, 8] U_gamma;
  vector[8] d_gamma;
  matrix[16, 16] U_delta;
  vector[16] d_delta;
}
parameters {
  vector[1] alpha1;
  vector[3] beta1;
  vector[7] gamma1;
  vector[15] delta1;
}
transformed parameters {
  vector[2] alpha0 = diag_post_multiply(U_alpha, d_alpha) * append_row(alpha1, 0);
  vector[4] beta0 = diag_post_multiply(U_beta, d_beta) * append_row(beta1, 0);
  vector[8] gamma0 = diag_post_multiply(U_gamma, d_gamma) * append_row(gamma1, 0);
  vector[16] delta0 = diag_post_multiply(U_delta, d_delta) * append_row(delta1, 0);
}
model {
  alpha0 ~ normal(0, inv(sqrt(1 - inv(2))));
  beta0 ~ normal(0, inv(sqrt(1 - inv(4))));
  gamma0 ~ normal(0, inv(sqrt(1 - inv(8))));
  delta0 ~ normal(0, inv(sqrt(1 - inv(16))));
}
generated quantities {
  vector[2] alpha = softmax(alpha0);
  vector[4] beta = softmax(beta0);
  vector[8] gamma = softmax(gamma0);
  vector[16] delta = softmax(delta0);
}

where a normal(0, 1) prior is implied on each *0 variable:

      variable  mean median   sd  mad    q5  q95 rhat ess_bulk ess_tail
 1:  alpha0[1] -0.01  -0.01 0.98 0.98 -1.61 1.58    1     9447     2933
 2:  alpha0[2]  0.01   0.01 0.98 0.98 -1.58 1.61    1     9447     2933
 3:   beta0[1]  0.00   0.02 0.99 0.96 -1.68 1.61    1     8745     2809
 4:   beta0[2]  0.00  -0.01 1.04 1.02 -1.72 1.71    1     8809     2792
 5:   beta0[3]  0.00   0.00 1.00 1.01 -1.63 1.66    1    10303     3109
 6:   beta0[4]  0.00   0.00 1.01 1.02 -1.62 1.67    1    10336     3066
 7:  gamma0[1]  0.00   0.00 1.01 1.01 -1.65 1.66    1    11362     3020
 8:  gamma0[2] -0.02  -0.02 1.00 0.98 -1.67 1.63    1     9431     3208
 9:  gamma0[3]  0.01   0.02 1.02 1.04 -1.68 1.66    1    10264     2642
10:  gamma0[4]  0.02   0.01 1.01 1.00 -1.65 1.68    1     9298     3192
11:  gamma0[5] -0.01   0.00 1.02 1.04 -1.67 1.69    1    10236     2960
12:  gamma0[6] -0.02  -0.01 0.99 0.96 -1.65 1.65    1     8373     2679
13:  gamma0[7]  0.01   0.01 1.01 0.98 -1.63 1.63    1    11449     3008
14:  gamma0[8]  0.01   0.01 1.03 1.05 -1.70 1.66    1     9177     2932
15:  delta0[1]  0.01   0.04 1.00 1.00 -1.62 1.60    1     9758     3013
16:  delta0[2]  0.01   0.02 0.96 0.96 -1.57 1.56    1    10035     2687
17:  delta0[3] -0.01  -0.01 1.01 1.01 -1.70 1.65    1     8770     2891
18:  delta0[4] -0.02  -0.02 0.99 1.00 -1.66 1.61    1    10356     3081
19:  delta0[5]  0.00   0.01 1.03 1.04 -1.69 1.69    1     9061     3049
20:  delta0[6]  0.01   0.02 0.99 0.98 -1.66 1.64    1     9019     3053
21:  delta0[7]  0.00  -0.01 1.00 0.99 -1.61 1.64    1     9416     2849
22:  delta0[8]  0.01   0.02 1.01 1.02 -1.65 1.66    1    10089     2983
23:  delta0[9]  0.00   0.00 1.01 0.96 -1.66 1.71    1     9815     2581
24: delta0[10]  0.00  -0.03 1.02 1.00 -1.66 1.65    1     8671     2961
25: delta0[11] -0.02  -0.03 1.02 1.04 -1.69 1.69    1     9732     2848
26: delta0[12] -0.01   0.01 1.00 0.99 -1.64 1.63    1    10818     2926
27: delta0[13]  0.01   0.01 0.98 0.98 -1.55 1.65    1    10413     2816
28: delta0[14]  0.00   0.00 1.01 1.00 -1.65 1.66    1     9977     2958
29: delta0[15]  0.02   0.03 1.03 1.02 -1.68 1.70    1    10193     2870
30: delta0[16]  0.00  -0.03 1.02 1.03 -1.66 1.67    1    11694     2749

bob-carpenter commented 3 months ago

I think the general problem with the SVD-based and QR-based approaches is that they require doing calculations with data/transformed data. This would be painful with our current architecture, where the constraining/unconstraining transforms are functions that don't take in auxiliary information. @WardBrian said it wouldn't be impossible, but it would be a lot of work and would complicate the parser.