mjhajharia / transforms


Create probit_product.stan #54

Closed · spinkney closed this 2 years ago

spinkney commented 2 years ago

I added the probit code that @sethaxen wrote about. @sethaxen please double check that I understood everything correctly.

spinkney commented 2 years ago

That derivation seemingly works for any distribution. I tested Student t and exponential, and it seems to give a uniform simplex with the marginals distributed as whatever I put in.

Exponential

data {
  // assumed data block (not shown in the original snippet): N is the simplex length
  int<lower=2> N;
}
parameters {
  vector<lower=0>[N - 1] y;
}
transformed parameters {
  simplex[N] x;
  real log_det_jacobian = -lgamma(N);
  {
    real log_u, log_z;
    real sum_log_z = 0;
    for (i in 1:(N - 1)) {
      log_u = exponential_lcdf(y[i] | 1);
      log_z = log_u / (N - i);
      x[i] = exp(sum_log_z + log1m_exp(log_z));
      sum_log_z += log_z;
      log_det_jacobian += exponential_lpdf(y[i] | 1);
    }
    x[N] = exp(sum_log_z);
  }
}
model {
  // assumed model block (not shown in the original snippet):
  // this adjustment is what makes x uniform on the simplex
  target += log_det_jacobian;
}

and the output of

> out$summary(c("x", "y"))
# A tibble: 9 × 10
  variable  mean median    sd   mad     q5   q95  rhat ess_bulk ess_tail
  <chr>    <dbl>  <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 x[1]     0.195  0.156 0.159 0.152 0.0131 0.514  1.00    3037.    2018.
2 x[2]     0.199  0.154 0.164 0.149 0.0120 0.533  1.00    2642.    2071.
3 x[3]     0.205  0.162 0.166 0.156 0.0143 0.536  1.00    2745.    2168.
4 x[4]     0.202  0.160 0.160 0.151 0.0148 0.526  1.00    2520.    2416.
5 x[5]     0.199  0.162 0.161 0.154 0.0104 0.517  1.00    1953.    1894.
6 y[1]     1.02   0.710 0.983 0.732 0.0575 2.97   1.00    3037.    2018.
7 y[2]     1.02   0.730 1.01  0.734 0.0559 3.05   1.00    2571.    1838.
8 y[3]     0.977  0.705 0.960 0.722 0.0417 2.88   1.00    2661.    1545.
9 y[4]     0.978  0.683 0.991 0.690 0.0441 2.96   1.00    2606.    1395.
sethaxen commented 2 years ago

I haven't tested it, but the implementation looks correct. And yes, it should work generically. If we have a distribution whose support is the entire real line, then its CDF is a bijection, so we can use its CDF to constrain the real line to $[0, 1]$. Since the derivative of the CDF is the PDF, the two give us a new transform. But that excludes the exponential distribution unless you first pull it back through the $x = \exp(y)$ transform.
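
For reference, written out with a generic CDF $F$ and density $f$, and with $N$ the simplex length as in the Stan snippet above, the map is

$$u_i = F(y_i), \qquad z_i = u_i^{1/(N-i)}, \qquad x_i = \Bigl(\prod_{j<i} z_j\Bigr)(1 - z_i) \ \ (i = 1, \dots, N-1), \qquad x_N = \prod_{j=1}^{N-1} z_j,$$

and the `log_det_jacobian` accumulated in the code is $\sum_{i=1}^{N-1} \log f(y_i) - \log\Gamma(N)$.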

If we end up including the inverse-probit transform, then it might be worth generalizing the derivation in the paper; otherwise, I don't think this generality is terribly useful, since it just introduces too many choices that are no better motivated than logistic (or probit).

spinkney commented 2 years ago

Maybe it's overkill for getting a simplex parameter, but I think this could be very useful for compositional data analysis. Does this generalize to non-uniform simplices? Because if you have data on the simplex that you want to analyze, it could be simpler to run it through the reverse (our forward transform) to a standard normal, do the analysis on that scale, then push the results back to the simplex.

For the exponential, I first constrained y to be positive.

sethaxen commented 2 years ago

> Does this generalize to non-uniform simplices?

It does in the sense that this is a transform that can be used regardless of the target distribution, but if the target distribution is non-uniform on the simplex and one pulls it back to the latent space through one of these transforms, one should not expect to get a nice distribution of this form. Namely, I think one would lose independence of the marginals. The easiest way to check is to play with the Dirichlet. That's one of the reasons I'm curious to see how this performs. For $\alpha=1$ it has nice geometry, but I have no idea how it will stack up against e.g. stick-breaking for very low or high $\alpha$.

Btw, here's some Julia code for trying out other generic distributions:

julia> using LogExpFunctions, SpecialFunctions, Distributions, ForwardDiff, LinearAlgebra  # LinearAlgebra for logabsdet below

julia> function hyperspherical_dist_transform(y, d)
           N = length(y)
           T = float(eltype(y))
           x = similar(y, T, N + 1)
           logJ = sum_logz = zero(T)
           for i in eachindex(y)
               logu = logcdf(d, y[i])
               logz = logu / (N - i + 1)
               x[i] = exp(sum_logz + log1mexp(logz))
               logJ += logpdf(d, y[i])
               sum_logz += logz
           end
           x[N + 1] = exp(sum_logz)
           logJ -= logfactorial(N)
           return x, logJ
       end
hyperspherical_dist_transform (generic function with 1 method)

julia> dist = Cauchy();

julia> N = 9;

julia> y = randn(N);

julia> x, logJ = hyperspherical_dist_transform(y, dist);

julia> J = ForwardDiff.jacobian(x -> hyperspherical_dist_transform(x, dist)[1][1:N], y);

julia> logabsdet(J)[1] ≈ logJ
true

julia> all(>(0), x) && sum(x) ≈ 1
true
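
A quick way to poke at the independence question above is to pull Dirichlet draws back to the latent space and look at the correlations of the latent coordinates. Here is a minimal sketch (not from the thread; `pullback_dist_transform` is made up for illustration); it inverts the map above via $z_i = 1 - x_i / \sum_{k \ge i} x_k$, $u_i = z_i^{N-i}$, $y_i = F^{-1}(u_i)$, with $N$ the simplex length:

```julia
# Sketch: invert the transform above, pull Dirichlet draws back to the latent
# space, and inspect the latent correlations.
using Distributions, Statistics

function pullback_dist_transform(x, d)
    N = length(x)                      # simplex length; latent dimension is N - 1
    y = similar(x, N - 1)
    tail = one(eltype(x))              # running value of sum(x[i:N]); starts at 1
    for i in 1:(N - 1)
        z = 1 - x[i] / tail            # z_i = 1 - x_i / sum(x[i:N])
        y[i] = quantile(d, z^(N - i))  # u_i = z_i^(N - i), y_i = F⁻¹(u_i)
        tail -= x[i]
    end
    return y
end

alpha = [1.0, 5.0, 3.0, 4.0]           # a non-uniform Dirichlet, as in the thread
nsim = 10_000
xs = rand(Dirichlet(alpha), nsim)      # length(alpha) × nsim matrix of simplex draws
ys = Matrix{Float64}(undef, nsim, length(alpha) - 1)
for (j, x) in enumerate(eachcol(xs))
    ys[j, :] = pullback_dist_transform(x, Normal())
end
cor(ys)                                # correlation matrix of the latent coordinates
```

Nonzero off-diagonal correlations would be exactly the loss of independence mentioned above.
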
spinkney commented 2 years ago

Thanks 👍! I want to test a normal copula, which will keep the marginals independent and only link them in the lpdf calculation.

sethaxen commented 2 years ago

In an attempt at consistency, we should either call this the inverse probit transform or call the one in #55 the logit transform (instead of logistic). If, as Bob suggested, we define transforms as being from constrained to unconstrained, then hyperspherical-logit and hyperspherical-probit make sense.

spinkney commented 2 years ago

@sethaxen it actually works to get a regression on the latent scale

# Specify a Dirichlet(alpha) distribution for testing.
alpha <- c(1,5,3,4)

# Simulate and plot compositional data.
n <- 100
k <- length(alpha)
x <- matrix(rgamma(n*k, alpha), nrow=n, byrow=TRUE)
Stan model

```stan
functions {
  vector probit_simplex_lp(vector y) {
    int N = num_elements(y) + 1;
    vector[N] x;
    real log_det_jacobian = -lgamma(N);
    {
      real log_u, log_z;
      real sum_log_z = 0;
      for (i in 1:N - 1) {
        log_u = std_normal_lcdf(y[i] |);
        log_z = log_u / (N - i);
        x[i] = exp(sum_log_z + log1m_exp(log_z));
        sum_log_z += log_z;
        log_det_jacobian += std_normal_lpdf(y[i] |);
      }
      x[N] = exp(sum_log_z);
    }
    return x;
  }

  vector probit_simplex(vector y) {
    int N = num_elements(y) + 1;
    vector[N] x;
    {
      real log_u, log_z;
      real sum_log_z = 0;
      for (i in 1:N - 1) {
        log_u = std_normal_lcdf(y[i] |);
        log_z = log_u / (N - i);
        x[i] = exp(sum_log_z + log1m_exp(log_z));
        sum_log_z += log_z;
      }
      x[N] = exp(sum_log_z);
    }
    return x;
  }

  vector reverse_probit_simplex_lp(vector x) {
    int Nm1 = num_elements(x) - 1;
    vector[Nm1] y;
    real log_det_jacobian = 0;
    {
      for (i in 1:Nm1) {
        real x_rev = x[i] / sum(x[i:Nm1 + 1]);
        real cache = pow(1 - x_rev, Nm1 - i + 1);
        y[i] = inv_Phi(cache);
        log_det_jacobian += log(Nm1 - i + 1) - std_normal_lpdf(cache |) + (Nm1 - i) * log1m(x_rev);
      }
    }
    return y;
  }
}
data {
  int N;
  int J;
  array[N] vector[J] x;
}
parameters {
  vector[J - 1] mu;
}
transformed parameters {
  array[N] vector[J - 1] y;
  for (i in 1:N) {
    y[i] = reverse_probit_simplex_lp(x[i]);
  }
}
model {
  for (i in 1:N)
    y[i] - mu ~ std_normal();
}
generated quantities {
  vector[J] mu_simplex = probit_simplex(mu);
}
```

Doing the regression on the latent scale and then mapping the induced mu back onto the simplex gives values comparable to the means of the Dirichlet:

# A tibble: 4 × 10
  variable        mean median      sd     mad     q5    q95  rhat ess_bulk ess_tail
  <chr>          <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
1 mu_simplex[1] 0.0581 0.0577 0.00915 0.00886 0.0442 0.0742  1.00    3230.    2535.
2 mu_simplex[2] 0.362  0.362  0.0292  0.0295  0.315  0.409   1.00    3821.    3205.
3 mu_simplex[3] 0.239  0.238  0.0256  0.0265  0.198  0.282   1.00    3983.    2973.
4 mu_simplex[4] 0.341  0.340  0.0285  0.0284  0.295  0.389   1.00    3927.    2914.

> alpha / sum(alpha)
[1] 0.07692308 0.38461538 0.23076923 0.30769231
spinkney commented 2 years ago

Ok, I'm good with any name. Is it hyperspherical_probit then?

sethaxen commented 2 years ago

> @sethaxen it actually works to get a regression on the latent scale

I'll look at this more closely later.

> Ok, I'm good with any name. Is it hyperspherical_probit then?

I think yes. I'll update #55 and the paper text accordingly.

spinkney commented 2 years ago

@sethaxen I had to make one update to the logic: when `i == 1`, the line becomes

log_z = i == 1 ?  log_u / (N - i - 1) : log_u / (N - i );

then I get

> out$summary("mu_simplex")
# A tibble: 4 × 10
  variable        mean median      sd     mad     q5    q95  rhat ess_bulk ess_tail
  <chr>          <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
1 mu_simplex[1] 0.0770 0.0769 0.00391 0.00390 0.0707 0.0836  1.00    3898.    2506.
2 mu_simplex[2] 0.381  0.381  0.00954 0.00971 0.365  0.397   1.00    4265.    2818.
3 mu_simplex[3] 0.233  0.232  0.00803 0.00783 0.219  0.246   1.00    4098.    2760.
4 mu_simplex[4] 0.310  0.310  0.00858 0.00850 0.296  0.324   1.00    4215.    3243.

which is exactly what you'd expect with alpha = [1, 5, 3, 4].

I believe it is already written correctly in the paper. But it doesn't give out uniform values on the simplex when I test...

So this is really, really cool! No need to do dirichlet regressions or fiddle with alpha's!!

vector probit_simplex(vector y) {
  int N = num_elements(y) + 1;
  vector[N] x;
  {
    real log_u, log_z;
    real sum_log_z = 0;
    for (i in 1:N - 1) {
      log_u = normal_lcdf(y[i] | 0, 1);
      log_z = i == 1 ? log_u / (N - i - 1) : log_u / (N - i);
      x[i] = exp(sum_log_z + log1m_exp(log_z));
      sum_log_z += log_z;
    }
    x[N] = exp(sum_log_z);
  }
  return x;
}
sethaxen commented 2 years ago

> @sethaxen I had to make one update to the logic: when `i == 1`, the line becomes

log_z = i == 1 ?  log_u / (N - i - 1) : log_u / (N - i );

What is the reasoning behind this change?

> I believe it is already written correctly in the paper. But it doesn't give out uniform values on the simplex when I test...

I'll double check the math later tonight, but one thing to note is that the math in the paper and the code in https://github.com/mjhajharia/transforms/pull/54#issuecomment-1232219174 use a different definition of $N$ than the Stan implementations in this repo. When going from the paper to the repo, we currently need to do $N \to N-1$.
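
Concretely, assuming the paper matches that comment's Julia convention where $N$ is the length of $y$: the step-$i$ exponent there is $1/(N - i + 1)$, which becomes $1/(N - i)$ once $N$ denotes the simplex length, as in the Stan code here.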

sethaxen commented 2 years ago

I double-checked, and I don't see any reason why it should be `(N - i - 1)`. It should be `(N - i)`. I also sampled your probit.stan with Stan with `alpha = fill(1, 10)`, and the ECDFs of the marginals concentrate around the true marginal.

[plot: ECDFs of the simplex marginals]

spinkney commented 2 years ago

> I double-checked, and I don't see any reason why it should be `(N - i - 1)`. It should be `(N - i)`. I also sampled your probit.stan with Stan with `alpha = fill(1, 10)`, and the ECDFs of the marginals concentrate around the true marginal.

Yes, it's definitely right and performs as expected when alpha is a vector of 1s. When alpha is composed of positive reals that aren't all 1, I initially hoped that the above code would recover the expected values from the Dirichlet distribution with the given alpha. It gets close in many cases but concentrates on something a bit different. In the example above the first value happens to be lower than expected; even with more data it continues to concentrate around 0.054 rather than 0.077. This happens for other choices of alpha as well. I suspect it's a distortion from the assumed normality, and that there won't be a perfect mapping from Dirichlet-distributed values to a logit/probit normal.

sethaxen commented 2 years ago

Ok, but you're not talking about the accuracy of the expectations when sampling from the Dirichlet using this transform, right? In that case I see the following results:

                    Mean     MCSE  StdDev        5%    50%   95%    N_Eff  N_Eff/s    R_hat

lp__                -2.4  2.7e-02     1.2  -4.8e+00   -2.1  -1.1     2018    20805      1.0
accept_stat__       0.92  1.4e-03   0.098      0.71   0.95   1.0  5.2e+03  5.3e+04  1.0e+00
stepsize__          0.79  2.1e-02   0.030      0.75   0.82  0.83  2.0e+00  2.1e+01  8.8e+12
treedepth__          2.3  9.4e-03    0.57       1.0    2.0   3.0  3.6e+03  3.7e+04  1.0e+00
n_leapfrog__         4.7  3.4e-02     2.0       3.0    3.0   7.0  3.7e+03  3.8e+04  1.0e+00
divergent__         0.00      nan    0.00      0.00   0.00  0.00      nan      nan      nan
energy__             3.9  4.2e-02     1.7       1.7    3.6   7.3  1.6e+03  1.7e+04  1.0e+00

y[1]                 1.0  1.1e-02    0.69  -4.1e-02   0.99   2.2     4159    42874     1.00
y[2]               -0.41  7.5e-03    0.46  -1.2e+00  -0.40  0.35     3785    39017     1.00
y[3]                0.20  8.1e-03    0.49  -5.9e-01   0.19   1.0     3747    38625      1.0
x[1]               0.076  1.1e-03   0.070   4.4e-03  0.057  0.21     4247    43783      1.0
x[2]                0.39  2.1e-03    0.13   1.8e-01   0.38  0.61     3794    39114     1.00
x[3]                0.23  1.9e-03    0.11   7.3e-02   0.22  0.44     3619    37309     1.00
x[4]                0.31  2.0e-03    0.13   1.3e-01   0.29  0.54     4044    41693     1.00
log_det_jacobian    -5.6  1.7e-02    0.89  -7.5e+00   -5.4  -4.7     2872    29605      1.0

and the expectations are within the MCSE of the expected means of the marginals: [0.077, 0.385, 0.231, 0.308]

spinkney commented 2 years ago

Oh no, this performs really well in that case!

bob-carpenter commented 2 years ago

> Does this generalize to non-uniform simplices?

What's nice is that if we have a transform to uniform simplices, then we can just add in another log density to get a non-uniform distribution. Of course, that might not always be the best move in terms of precision or efficiency, but at least it's compositional.
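
A minimal sketch of that idea in Julia (the wrapper name below is made up for illustration), reusing `hyperspherical_dist_transform` from the snippet earlier in the thread: the log-Jacobian term `logJ` plus the extra Dirichlet log density gives the latent log density whose pushforward through the transform is that Dirichlet.

```julia
# Sketch: compose the uniform-simplex transform with an extra log density.
# Assumes hyperspherical_dist_transform from the Julia snippet above is in scope.
using Distributions

function dirichlet_latent_logdensity(y, alpha; d = Normal())
    x, logJ = hyperspherical_dist_transform(y, d)  # simplex point and the log-Jacobian term verified above
    return logJ + logpdf(Dirichlet(alpha), x)      # add the target's log density on the simplex
end

# e.g. hand this to any sampler over unconstrained y with length(alpha) - 1 components:
dirichlet_latent_logdensity(randn(3), [1.0, 5.0, 3.0, 4.0])
```

Swapping `Dirichlet(alpha)` for any other density on the simplex gives the corresponding non-uniform target, which is the compositionality point above.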

And it's exactly this kind of "we can use any CDF, not just the standard logistic CDF" that is going to lead to a combinatorial explosion in possible parameterizations to evaluate. It would be nice if we could say something more general about tails, etc.