mjhajharia / transforms


more stable stick breaking #66

Closed. spinkney closed this issue 1 year ago.

spinkney commented 1 year ago

This calculates everything on the log scale and uses a streaming version of log_sum_exp to get the cumulative log_sum_exp. The calculation starts from the observation that

$$
\begin{aligned}
y &= \log\left(1 - \sum \exp(x)\right) \\
1 - \exp(y) &= \sum \exp(x) \\
\log\left(1 - \exp(y)\right) &= \text{LSE}(x)
\end{aligned}
$$

All that remains is to compute a cumulative LSE. We can do this by keeping track of the running max and, whenever the next element is a new max, rescaling the running sum to it before adding that element.
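Below is a minimal Python sketch of the running-max recursion just described. Python and the test values are illustrative choices, not code from the PR. It mirrors the Stan `cumulative_logsumexp` step for step:

- `r` holds the sum of `exp(x[j] - running_max)` over the elements seen so far.
- When a new maximum arrives, `r` is first rescaled to that maximum, and then the new element contributes `exp(0) = 1`.

```python
import math


def cumulative_logsumexp(x):
    """y[n] = log(sum(exp(x[0..n]))), computed in a single pass.

    r is the sum of exp(x[j] - running_max) over the elements seen so far.
    As in the Stan version, a -inf element seen while running_max is still
    -inf yields nan.
    """
    running_max = -math.inf
    r = 0.0
    y = []
    for xn in x:
        if xn <= running_max:
            # Not a new maximum: add its contribution on the current scale.
            r += math.exp(xn - running_max)
        else:
            # New maximum: rescale the old sum to the new max, then add
            # exp(xn - xn) = 1 for the new element.
            r = r * math.exp(running_max - xn) + 1.0
            running_max = xn
        y.append(math.log(r) + running_max)
    return y


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0, 1.5]
    naive = [math.log(sum(math.exp(v) for v in x[: n + 1])) for n in range(len(x))]
    print(cumulative_logsumexp(x))
    print(naive)                                   # same values up to rounding
    print(cumulative_logsumexp([1000.0, 1000.0]))  # [1000.0, 1000 + log(2)]
```

The last call shows why the rescaling matters. A naive `log(sum(exp(x)))` overflows for inputs near 1000, while the running-max form stays finite.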

I will add the PR next.

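```stan
functions {
  // Cumulative log-sum-exp: y[n] = log(sum(exp(x[1:n]))) in one pass,
  // tracking the running max and rescaling the running sum r on a new max.
  vector cumulative_logsumexp(vector x) {
    real running_max = negative_infinity();
    real r = 0;
    int N = num_elements(x);
    vector[N] y;

    for (n in 1:N) {
      if (x[n] <= running_max) {
        r += exp(x[n] - running_max);
      } else {
        r *= exp(running_max - x[n]);
        r += 1.;
        running_max = x[n];
      }
      y[n] = log(r) + running_max;
    }
    return y;
  }

  // Stick-breaking on the log scale: returns log(pi) for the K-simplex pi
  // and adds the log absolute Jacobian determinant to target.
  vector break_that_stick_lp(vector stick_slices) {
    int K = num_elements(stick_slices) + 1;
    vector[K] log_pi;
    real logabsjac = 0;
    // logit of each stick proportion: offset by log(K - k), k = 1, ..., K - 1
    vector[K - 1] z = stick_slices - log(reverse(linspaced_vector(K - 1, 1, K - 1)));

    // log of the stick proportions inv_logit(z)
    log_pi[1:K - 1] = log_inv_logit(z);
    logabsjac += sum(log_pi[1:K - 1]);

    // add the log of the stick remaining before each break
    log_pi[K] = 0;
    log_pi[2:K] += cumulative_sum(log1m_inv_logit(z));

    // log_pi[K] = sum(log1m_inv_logit(z)), the log(1 - proportion) terms
    logabsjac += log_pi[K];
    // log(1 - sum(pi[1:k])) for k = 1, ..., K - 2
    logabsjac += sum(log1m_exp(cumulative_logsumexp(log_pi[1:K - 2])));

    target += logabsjac;

    return log_pi;
  }
}

data {
  int<lower=0> N;
  vector<lower=0>[N] alpha;
}
parameters {
  vector[N - 1] y;
}
transformed parameters {
  simplex[N] x = exp(break_that_stick_lp(y));
}
model {
}
```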
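To check the log-scale transform outside Stan, here is a Python sketch that mirrors `break_that_stick_lp`. Instead of incrementing `target`, it returns `log_pi` together with the log absolute Jacobian. A few parts are assumptions made for illustration:

- The name `break_that_stick` is hypothetical.
- The helpers `log_inv_logit` and `log1m_inv_logit` imitate the Stan built-ins of the same names.
- To keep the block short, the cumulative log-sum-exp uses the direct pairwise recursion rather than the running-max version.

Because x[k] depends only on y[0..k], the Jacobian of y to x[0..K-2] is lower triangular. A finite-difference check of its diagonal is therefore enough to test log|J|.

```python
import math


def log_inv_logit(u):
    # log(inv_logit(u)) without overflow for large |u| (imitates Stan's log_inv_logit)
    if u >= 0:
        return -math.log1p(math.exp(-u))
    return u - math.log1p(math.exp(u))


def log1m_inv_logit(u):
    # log(1 - inv_logit(u)) = log(inv_logit(-u))
    return log_inv_logit(-u)


def cumulative_logsumexp(v):
    # Direct recursion out[n] = log(exp(out[n-1]) + exp(v[n])); assumes finite inputs.
    out = []
    for a in v:
        if out:
            b = out[-1]
            out.append(max(a, b) + math.log1p(math.exp(-abs(a - b))))
        else:
            out.append(a)
    return out


def break_that_stick(y):
    """Mirror of break_that_stick_lp: returns (log_pi, log_abs_jacobian)."""
    K = len(y) + 1
    # Stan: z = stick_slices - log([K-1, ..., 1])
    z = [y[i] - math.log(K - 1 - i) for i in range(K - 1)]
    log_pi = [log_inv_logit(zi) for zi in z] + [0.0]
    log_abs_jac = sum(log_pi[:K - 1])             # sum of log stick proportions
    acc = 0.0
    for k in range(1, K):                         # Stan: log_pi[2:K] += cumulative_sum(...)
        acc += log1m_inv_logit(z[k - 1])
        log_pi[k] += acc
    log_abs_jac += log_pi[K - 1]                  # sum of log(1 - proportion)
    # log of the stick left before breaks 2..K-1, via log1m_exp(cumulative LSE)
    log_abs_jac += sum(math.log(-math.expm1(s))
                       for s in cumulative_logsumexp(log_pi[:K - 2]))
    return log_pi, log_abs_jac


if __name__ == "__main__":
    y = [0.3, -1.2, 0.8, 2.0]
    log_pi, laj = break_that_stick(y)
    print(sum(math.exp(v) for v in log_pi))       # 1 up to rounding

    # dx/dy is lower triangular, so log|det J| is the sum of log dx[k]/dy[k].
    h = 1e-6
    fd = 0.0
    for k in range(len(y)):
        yp, ym = list(y), list(y)
        yp[k] += h
        ym[k] -= h
        xp = math.exp(break_that_stick(yp)[0][k])
        xm = math.exp(break_that_stick(ym)[0][k])
        fd += math.log((xp - xm) / (2 * h))
    print(laj, fd)                                # should agree closely
```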
sethaxen commented 1 year ago

Can you add a little more info about why the proposed changes are necessary? What are y and x here, and how are they related to the stick-breaking transform? Why is a cumulative logsumexp needed? It seems that one could just use y = log1m_exp(log_sum_exp(x)).

spinkney commented 1 year ago

I believe the cumulative sum is necessary for the Jacobian. I did 4 ad hoc tests writing the function differently, and this one had no warnings/errors and was the fastest at dimensions > 500.

I can post the other stuff I tried tomorrow.

(Replying by email to Seth Axen's comment of Mon, Aug 28, 2023, 4:10 PM.)

bob-carpenter commented 1 year ago

Cool. I like breaking it apart this way. And this is just what the _lp functions were designed for.

Unless the way you did it is a lot faster or more stable, I think the direct approach is easier to understand:

```stan
vector cum_log_sum_exp(vector x) {
  if (rows(x) < 1) return x;
  vector[rows(x)] y;
  y[1] = x[1];
  for (n in 2:rows(x))
    y[n] = log_sum_exp(y[n - 1], x[n]);
  return y;
}
```

If your way is faster or more stable, I think it can be improved in two ways:

  1. Use log1p or even go all the way to log1p_exp() in the second branch for arithmetic stability (probably slower but more stable).
  2. Avoid the negative infinity with a size test up front (skips internal branch on negative infinity and saves reader some mental arithmetic).
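Here is one possible reading of the two suggestions, sketched in Python. The function name `cumulative_logsumexp_log1p` and the choice to track `s = log(r)` are assumptions, not code from this thread. The two changes are:

- An up-front size test lets the running max start at `x[0]` instead of negative infinity.
- Both branches update `log(r)` through `log1p`, which applies the log1p_exp idea to the running sum.

```python
import math


def cumulative_logsumexp_log1p(x):
    """Running-max cumulative log-sum-exp kept on the log scale.

    s = log(r), where r = sum(exp(x[j] - running_max)) over elements seen so
    far. Assumes finite inputs (a -inf after a leading -inf gives nan).
    """
    if len(x) == 0:          # size test up front: no -inf sentinel needed
        return []
    running_max = x[0]
    s = 0.0                  # r = 1 after the first element
    y = [x[0]]
    for xn in x[1:]:
        if xn <= running_max:
            # log(r + exp(xn - max)) = s + log1p(exp(xn - max - s)); argument <= 0
            s += math.log1p(math.exp(xn - running_max - s))
        else:
            # new max: log(r * exp(old_max - xn) + 1)
            s = math.log1p(math.exp(s + running_max - xn))
            running_max = xn
        y.append(running_max + s)
    return y


if __name__ == "__main__":
    print(cumulative_logsumexp_log1p([0.5, -1.0, 2.0, 1.5]))
    print(cumulative_logsumexp_log1p([1000.0, 999.0, 1001.0]))  # no overflow
```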

You can speed up break_that_stick by caching the value of log(reverse(linspaced_vector(K - 1, 1, K - 1))) in transformed data (until closures, this will require a second argument to pass it in, but it gets passed by constant reference in C++). And how about a more descriptive name like stick_break_simplex_constrain?

sethaxen commented 1 year ago

> I believe the cumulative sum is necessary for the Jacobian. I did 4 ad hoc tests writing the function differently, and this one had no warnings/errors and was the fastest at dimensions > 500. I can post the other stuff I tried tomorrow.

My question in https://github.com/mjhajharia/transforms/issues/66#issuecomment-1696341201 is not about code, it's about the math. In the OP, you write

$$ y = \log(1 - \sum \exp(x)). $$

What do these variables mean? By the notation in the paper, $y$ would be the vector of unconstrained parameters, and $x$ would be the vector of constrained parameters, but our stick-breaking transform is written

$$z_i = x_i / \left(1 - \sum_{k=1}^{i-1} x_k\right)$$

$$y_i = \mathrm{logit}(z_i) + \log(N - i) = \log(x_i) - \log\left(1 - \sum_{k=1}^{i} x_k\right) + \log(N - i),$$

which is not the same expression as in the OP. I am asking if you can relate this math to the stick-breaking transform as described in the paper, so I can understand what this PR is doing.
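Here is a small Python round-trip sketch of the two formulas above. The function names are hypothetical, and the forward map uses the same log(N − i) offset as the Stan code in this issue:

- `stick_breaking_unconstrain` applies the y_i formula directly.
- `stick_breaking_constrain` breaks the stick on the linear scale.

Mapping a simplex to y and back should therefore reproduce it.

```python
import math


def stick_breaking_unconstrain(x):
    """y_i = log(x_i) - log(1 - sum_{k<=i} x_k) + log(N - i), for i = 1..N-1."""
    N = len(x)
    y = []
    cum = 0.0
    for i in range(1, N):  # 1-based i, as in the formula
        cum += x[i - 1]
        y.append(math.log(x[i - 1]) - math.log(1.0 - cum) + math.log(N - i))
    return y


def stick_breaking_constrain(y):
    """z_i = inv_logit(y_i - log(N - i)); break off z_i of the remaining stick."""
    N = len(y) + 1
    x = []
    remaining = 1.0
    for i in range(1, N):
        z = 1.0 / (1.0 + math.exp(-(y[i - 1] - math.log(N - i))))
        x.append(z * remaining)
        remaining *= 1.0 - z
    x.append(remaining)  # x_N takes whatever is left of the stick
    return x


if __name__ == "__main__":
    x = [0.1, 0.2, 0.3, 0.4]
    y = stick_breaking_unconstrain(x)
    print(y)
    print(stick_breaking_constrain(y))             # recovers x up to rounding
    print(stick_breaking_unconstrain([0.25] * 4))  # zeros up to rounding
```

The last call shows that the log(N − i) offset maps the uniform simplex to y = 0.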

spinkney commented 1 year ago

Taking the transform in https://github.com/mjhajharia/transforms/blob/main/transforms/simplex/Stickbreaking.stan, I put everything on the log scale. It is exactly the same transform as the stick-breaking one.
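Here is a derivation sketch of how the log-scale code relates to Seth's formulas. Two notational assumptions apply:

- $z_i$ is the stick proportion, as in Seth's comment.
- In the Stan code, the variable named `z` holds $\mathrm{logit}(z_i)$.

The original post's identity can then be read with its $x$ taken as $\log x_1, \dots, \log x_{i-1}$ and its $y$ as the log of the stick remaining before break $i$:

$$
\begin{aligned}
z_i &= \mathrm{logit}^{-1}\left(y_i - \log(N - i)\right), \qquad i = 1, \dots, N - 1, \\
\log\left(1 - \sum_{k=1}^{i-1} x_k\right) &= \sum_{k=1}^{i-1} \log(1 - z_k) = \log\left(1 - \exp\left(\text{LSE}(\log x_1, \dots, \log x_{i-1})\right)\right), \\
\log x_i &= \log z_i + \sum_{k=1}^{i-1} \log(1 - z_k), \qquad \log x_N = \sum_{k=1}^{N-1} \log(1 - z_k), \\
\log \lvert J \rvert &= \sum_{i=1}^{N-1} \left[\log z_i + \log(1 - z_i) + \log\left(1 - \sum_{k=1}^{i-1} x_k\right)\right].
\end{aligned}
$$

The second line gives two equivalent expressions for the remaining stick. `break_that_stick_lp` uses the cumulative sum of $\log(1 - z_k)$ to build `log_pi`. It uses the cumulative log-sum-exp form for the remaining-stick term of the Jacobian, which is zero for $i = 1$.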

@bob-carpenter, your cum_log_sum_exp version is faster.

spinkney commented 1 year ago

I can go through and create _lp functions for the transforms. I also think it would be nicer to have these encapsulated into functions.