spinkney closed this 1 year ago.
Can you add a little more info about why the proposed changes are necessary? What are `y` and `x` here, and how are they related to the stick-breaking transform? Why is a cumulative logsumexp needed? It seems that one could just use `y = log1m_exp(log_sum_exp(x))`.
I believe the cumulative sum is necessary for the Jacobian. I ran four ad hoc tests of different ways of writing the function, and this one produced no warnings or errors and was the fastest for dimensions above 500. I can post the other versions I tried tomorrow.
(Sent by email in reply to Seth Axen's comment of Mon, Aug 28, 2023, 4:10 PM.)
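A short derivation may help show why the cumulative sum enters the Jacobian. This sketch uses the paper's stick-breaking notation that appears later in the thread ($x$ the simplex, $z_i$ the break proportions, $N$ the simplex size). Each $x_i$ depends only on $y_1, \ldots, y_i$, so the Jacobian is triangular and its log-determinant is the sum of the diagonal terms below over $i = 1, \ldots, N - 1$:

$$
\begin{align}
x_i &= z_i \left(1 - \sum_{k=1}^{i-1} x_k\right), \qquad z_i = \mathrm{logit}^{-1}\big(y_i - \log(N - i)\big), \\
\log \left|\frac{\partial x_i}{\partial y_i}\right| &= \log z_i + \log(1 - z_i) + \log\left(1 - \sum_{k=1}^{i-1} x_k\right), \\
\log\left(1 - \sum_{k=1}^{i-1} x_k\right) &= \log\left(1 - \exp\big(\text{LSE}(\log x_1, \ldots, \log x_{i-1})\big)\right).
\end{align}
$$

Every partial sum $\sum_{k=1}^{i-1} x_k$ appears, not just the single total that `log1m_exp(log_sum_exp(x))` would return, which is where a cumulative LSE comes in.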
Cool. I like breaking it apart this way. And this is just what the `_lp` functions were designed for.
Unless the way you did it is a lot faster or more stable, I think the direct approach is easier to understand:
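```stan
// Cumulative log-sum-exp: y[n] = log(sum(exp(x[1:n])))
vector cum_log_sum_exp(vector x) {
  if (rows(x) < 1) return x;  // empty input: nothing to accumulate
  vector[rows(x)] y;
  y[1] = x[1];
  for (n in 2:rows(x))
    y[n] = log_sum_exp(y[n - 1], x[n]);
  return y;
}
```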
If your way is faster or more stable, I think it can be improved in two ways:

1. Use `log1p`, or even go all the way to `log1p_exp()`, in the second branch for arithmetic stability (probably slower, but more stable).
2. You can speed up `break_that_stick` by caching the value of `log(reverse(linspaced_vector(K - 1, 1, K - 1)))` in transformed data (until closures, this will require a second argument to pass it in, but it gets passed by constant reference in C++).

And how about a more descriptive name like `stick_break_simplex_constrain`?
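A minimal sketch of the caching idea, assuming a hypothetical `stick_break_simplex_constrain_lp` that receives the precomputed offsets $\log(K - i)$ as a `data vector` argument. This is not the `break_that_stick` code from the PR (its body isn't shown here); it is a plain log-scale loop version of the stick-breaking transform, written only to illustrate where the cached vector plugs in. `K = 5` is an arbitrary illustrative size.

```stan
functions {
  // Hypothetical name following the suggestion above; illustrative loop
  // version of the stick-breaking transform, not the code from the PR.
  vector stick_break_simplex_constrain_lp(vector y, data vector log_offsets) {
    int N = rows(y);     // N = K - 1 free parameters
    vector[N + 1] x;     // simplex of size K
    real log_stick = 0;  // log of remaining stick length, starts at log(1)
    for (i in 1:N) {
      real adj = y[i] - log_offsets[i];  // y_i - log(K - i)
      real log_z = log_inv_logit(adj);
      real log1m_z = log1m_inv_logit(adj);
      x[i] = exp(log_stick + log_z);
      // diagonal Jacobian term: log z_i + log(1 - z_i) + log(remaining stick)
      target += log_z + log1m_z + log_stick;
      log_stick += log1m_z;
    }
    x[N + 1] = exp(log_stick);  // last entry takes whatever stick is left
    return x;
  }
}
transformed data {
  int K = 5;  // illustrative simplex size
  // Computed once: log(K - i) for i = 1, ..., K - 1
  vector[K - 1] log_offsets = log(reverse(linspaced_vector(K - 1, 1, K - 1)));
}
parameters {
  vector[K - 1] y;
}
transformed parameters {
  vector[K] theta = stick_break_simplex_constrain_lp(y, log_offsets);
}
model {
  // With only the Jacobian in target, theta is uniform on the simplex.
}
```

Because `log_offsets` lives in transformed data, the `log`, `reverse`, and `linspaced_vector` calls run once rather than at every log-density evaluation.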
My question in https://github.com/mjhajharia/transforms/issues/66#issuecomment-1696341201 is not about code; it's about the math. In the OP, you write
$$ y = \log(1 - \sum \exp(x)). $$
What do these variables mean? By the notation in the paper, $y$ would be the vector of unconstrained parameters, and $x$ would be the vector of constrained parameters, but our stick-breaking transform is written
$$z_i = x_i \Big/ \left(1 - \sum_{k=1}^{i-1} x_k\right)$$
$$y_i = \mathrm{logit}(z_i) + \log(N - i) = \log(x_i) - \log\left(1 - \sum_{k=1}^{i} x_k\right) + \log(N - i),$$
which is not the same expression as in the OP. I am asking if you can relate this math to the stick-breaking transform as described in the paper, so I can understand what this PR is doing.
Taking the transform in https://github.com/mjhajharia/transforms/blob/main/transforms/simplex/Stickbreaking.stan, I put everything on the log scale. It is exactly the same transform as the stick-breaking one.
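To connect the two notations, read the OP's $x$ as the vector of log simplex entries $\log x_k$ (an assumption based on the reply above). The log of the stick remaining after $i$ breaks is then a cumulative log-sum-exp, and the paper's $y_i$ can be written entirely on the log scale:

$$
\begin{align}
\log\left(1 - \sum_{k=1}^{i} x_k\right) &= \log\left(1 - \exp\big(\text{LSE}(\log x_1, \ldots, \log x_i)\big)\right), \\
y_i &= \log(x_i) - \log\left(1 - \exp\big(\text{LSE}(\log x_1, \ldots, \log x_i)\big)\right) + \log(N - i).
\end{align}
$$

Under this reading, the OP's $y = \log(1 - \sum \exp(x))$ is the log of the remaining stick, and a cumulative LSE yields it for every $i$ in one pass.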
@bob-carpenter, your `cum_log_sum_exp` version is faster.
I can go through and create `_lp` functions for the transforms. I also think it would be nicer to have these encapsulated in functions.
This calculates things on the log scale and uses a streaming version of `log_sum_exp` to get the cumulative `log_sum_exp`. The calculation proceeds by first noticing that

$$
\begin{align}
&y = \log\left(1 - \sum \exp(x)\right) \\
&1 - \exp(y) = \sum \exp(x) \\
&\log(1 - \exp(y)) = \text{LSE}(x)
\end{align}
$$
All that remains is to compute a cumulative LSE. We can do this by keeping track of the running maximum and updating the running sum whenever the next element is a new maximum.
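A minimal sketch of that streaming cumulative LSE in Stan. The name `cum_log_sum_exp_streaming` is hypothetical and this is not the code from the PR. It assumes finite inputs: if the running maximum is `negative_infinity()` and the next element is too, the rescaling computes `exp(-inf - (-inf))` and returns NaN.

```stan
functions {
  // Streaming cumulative log-sum-exp: y[n] = log(sum(exp(x[1:n]))).
  // Keeps a running max m and a running sum s = sum(exp(x[k] - m)).
  vector cum_log_sum_exp_streaming(vector x) {
    int N = rows(x);
    if (N < 1) return x;  // empty input: nothing to accumulate
    vector[N] y;
    real m = x[1];  // running maximum
    real s = 1;     // running sum of exp(x[k] - m); exp(x[1] - m) = 1
    y[1] = x[1];
    for (n in 2:N) {
      if (x[n] <= m) {
        s += exp(x[n] - m);         // no new max: just accumulate
      } else {
        s = s * exp(m - x[n]) + 1;  // new max: rescale old sum, add exp(0)
        m = x[n];
      }
      y[n] = m + log(s);
    }
    return y;
  }
}
```

Each step costs one `exp` and one `log`, like the direct `log_sum_exp(y[n - 1], x[n])` recursion; since the direct `cum_log_sum_exp` was reported faster in this thread, the sketch mainly illustrates the running-maximum idea.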
I will add the PR next.