mjhajharia / transforms


Does the stick-breaking Jacobian need 0.5 * log(N)? #68

Closed: spinkney closed this issue 10 months ago

spinkney commented 11 months ago

When I use automatic differentiation (AD) through the stick-breaking transform in Julia, I need to add 0.5 * log(N) to get the same log-determinant of the Jacobian as our analytic formula.

import ForwardDiff
using LinearAlgebra
using LogExpFunctions

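# Stick-breaking transform: maps x in R^(N-1) to a point on the N-simplex.
function stick_break(x)
    Nm1 = length(x)
    # Break proportions, offset by log(N-1), ..., log(1) so that x = 0 maps to the simplex center
    z = logistic.(x .- log.(LinRange(Nm1, 1, Nm1)))
    cum_sum = zero(eltype(z))  # running sum of the simplex entries produced so far

    for n in 2:Nm1
        cum_sum += z[n - 1]
        z[n] *= (1 - cum_sum)  # take fraction z[n] of the remaining stick
    end

    return [z ; 1 - (cum_sum + z[Nm1])]  # last entry is whatever stick remains
end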

x = [1., 2., -0.4]
# AD Jacobian of the full N-vector output is N x (N-1), so use sqrt(det(J'J))
J = ForwardDiff.jacobian(stick_break, x)
logdet(J' * J) * 0.5 # -6.74405991829611

# what we have
N = length(x) + 1
z = logistic.(x .- log.(LinRange(N - 1, 1, N - 1)))
s = stick_break(x)  # simplex values, needed for the remaining-stick terms
our_jac_det = sum(log.(z) .+ log1p.(-z) .+ log1p.(-cumsum([0 ; s[1:N - 2]]))) # -7.437207098856055

# add in 0.5 * log(N)
our_jac_det + 0.5 * log(N) # -6.744059918296109
sethaxen commented 11 months ago

For the simplex, you should use the square Jacobian $J_s$, not $\sqrt{|J^\top J|}$. The base measure is the Lebesgue measure on the first $N-1$ elements, so simply drop the last element when computing the Jacobian.
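A minimal Julia sketch of this suggestion, reusing the `stick_break` function from the issue: differentiate only the first $N-1$ outputs with ForwardDiff so the Jacobian is square, then compare its `logabsdet` against the analytic formula. The names `Js`, `ad_logdet` and `analytic_logdet` are illustrative, and this assumes ForwardDiff and LogExpFunctions are installed.

```julia
import ForwardDiff
using LinearAlgebra
using LogExpFunctions

# Stick-breaking transform from R^(N-1) to the N-simplex (as in the issue).
function stick_break(x)
    Nm1 = length(x)
    z = logistic.(x .- log.(LinRange(Nm1, 1, Nm1)))
    cum_sum = zero(eltype(z))
    for n in 2:Nm1
        cum_sum += z[n - 1]
        z[n] *= (1 - cum_sum)
    end
    return [z; 1 - (cum_sum + z[Nm1])]
end

x = [1.0, 2.0, -0.4]
N = length(x) + 1

# Square (N-1) x (N-1) Jacobian: drop the last simplex coordinate before differentiating.
Js = ForwardDiff.jacobian(v -> stick_break(v)[1:end-1], x)
ad_logdet = logabsdet(Js)[1]

# Analytic log|J_s|: sum of log z + log(1 - z) + log(remaining stick) over the N-1 breaks.
z = logistic.(x .- log.(LinRange(N - 1, 1, N - 1)))
s = stick_break(x)
analytic_logdet = sum(log.(z) .+ log1p.(-z) .+ log1p.(-cumsum([0; s[1:N-2]])))

println(ad_logdet ≈ analytic_logdet)  # true
```

Because the stick-breaking Jacobian is lower triangular, its log-determinant is just the sum of the log diagonal entries, which is what the analytic formula computes; no $\sqrt{N}$ correction appears.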

The two Jacobians are related by

$$J = \begin{bmatrix} I_{N-1} & -\mathbf{1}_{N-1} \end{bmatrix}^\top J_s,$$

which is why

$$J^\top J = J_s^\top \left(I_{N-1} + \mathbf{1}_{N-1}\mathbf{1}_{N-1}^\top\right) J_s,$$

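$$\sqrt{|J^\top J|} = |J_s| \sqrt{|I_{N-1} + \mathbf{1}_{N-1}\mathbf{1}_{N-1}^\top|} = |J_s| \sqrt{1 + (N-1)} = |J_s| \sqrt{N},$$

where the middle step uses the matrix determinant lemma, $|I + \mathbf{u}\mathbf{v}^\top| = 1 + \mathbf{v}^\top\mathbf{u}$. Taking logs, $\tfrac{1}{2}\log|J^\top J| = \log|J_s| + \tfrac{1}{2}\log N$, which is exactly the extra $0.5 \log(N)$.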
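A quick numerical check of this identity, as a sketch: since it holds for any invertible $J_s$, a random matrix is used as a hypothetical stand-in for the stick-breaking Jacobian, and `P` plays the role of $\begin{bmatrix} I_{N-1} & -\mathbf{1}_{N-1} \end{bmatrix}^\top$.

```julia
using LinearAlgebra
using Random

Random.seed!(1)
N = 5
Js = randn(N - 1, N - 1)                                 # stand-in for the square Jacobian J_s
P = [Matrix{Float64}(I, N - 1, N - 1); -ones(1, N - 1)]  # N x (N-1): I_{N-1} stacked on -1^T
J = P * Js                                               # full N x (N-1) Jacobian

lhs = 0.5 * logdet(J' * J)                 # log sqrt|J^T J|
rhs = logabsdet(Js)[1] + 0.5 * log(N)      # log|J_s| + 0.5 log N
println(lhs ≈ rhs)                         # true

# The middle factor on its own: |I + 1 1^T| = N
println(det(I + ones(N - 1, N - 1)) ≈ N)   # true
```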

The $\sqrt{|J^\top J|}$ trick works at least when $J$ is a linear transformation of the square Jacobian of interest, but even then it can inject constant factors (here $\sqrt{N}$) into the Jacobian determinant.
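The same argument in general form, as a sketch assuming $J = A J_s$ for a fixed $N \times (N-1)$ matrix $A$ of full column rank:

$$\sqrt{|J^\top J|} = \sqrt{|J_s^\top A^\top A J_s|} = |J_s| \sqrt{|A^\top A|}.$$

So the trick recovers $|J_s|$ only up to the constant factor $\sqrt{|A^\top A|}$. For stick breaking, $A = \begin{bmatrix} I_{N-1} & -\mathbf{1}_{N-1} \end{bmatrix}^\top$ and $A^\top A = I_{N-1} + \mathbf{1}_{N-1}\mathbf{1}_{N-1}^\top$, giving $\sqrt{N}$; in log space this is the additive constant $\tfrac{1}{2}\log N$.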

spinkney commented 11 months ago

That's very clear, thanks!