TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Improve `with_logabsdet_jacobian` performance for `SimplexBijector` #303

Open torfjelde opened 2 months ago

torfjelde commented 2 months ago

The PR made me wonder whether it would be possible to improve the performance of `with_logabsdet_jacobian` for `SimplexBijector` by not performing `transform` and `logabsdetjac` separately when both are requested. This doesn't block the PR, though, and it might lead to more code duplication.

Originally posted by @devmotion in https://github.com/TuringLang/Bijectors.jl/pull/302#pullrequestreview-2000888436

sethaxen commented 1 month ago

Yes, it's absolutely possible. The simplex transform we use (as does Stan) is just the classic stick-breaking transform (called the inverse multiplicative log-ratio transform in the compositional data analysis literature), shifted in the unconstrained space so that, for every symmetric Dirichlet distribution, the mode of the unconstrained distribution is at the origin. The determinant of its Jacobian is the product of the elements of the output vector, so it is nearly free once the transform has been computed. For numerical stability, though, it's better to perform the entire transform on the log scale. There's a Stan implementation of this here: https://github.com/mjhajharia/transforms/blob/1207723c4c4208116f80204fe35f1631aaa30f6a/transforms/simplex/StickbreakingLogistic.stan#L2-L17
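
To make the idea concrete, here is a minimal sketch of a fused computation, written in plain Python rather than Julia so it stays independent of Bijectors.jl internals. It is not the library's code or the linked Stan program. It assumes the unconstrained-to-simplex direction, an input `y` of length K-1, and a shift of `-log(K - k)` for the k-th coordinate (1-indexed), which places `y = 0` at the uniform simplex point. The function names `log_sigmoid` and `simplex_with_logabsdet` are made up for illustration. Everything is accumulated on the log scale, and the log-determinant is the sum of the log outputs.

```python
import math


def log_sigmoid(t):
    """Numerically stable log(1 / (1 + exp(-t)))."""
    if t >= 0:
        return -math.log1p(math.exp(-t))
    return t - math.log1p(math.exp(t))


def simplex_with_logabsdet(y):
    """Map y in R^(K-1) to a point x on the K-simplex by stick-breaking.

    Returns (x, logabsdet), where logabsdet is the log absolute determinant
    of the Jacobian of the first K-1 outputs with respect to y.
    Hypothetical helper; the shift convention is an assumption.
    """
    K = len(y) + 1
    log_x = []
    log_remaining = 0.0  # log of the stick length that is still unassigned
    for k, yk in enumerate(y, start=1):
        t = yk - math.log(K - k)                 # assumed shift so that y = 0 is uniform
        log_x.append(log_remaining + log_sigmoid(t))  # log of the piece broken off
        log_remaining += log_sigmoid(-t)              # log(1 - sigmoid(t))
    log_x.append(log_remaining)                  # the last element takes what is left
    logabsdet = sum(log_x)                       # log of the product of all outputs
    x = [math.exp(v) for v in log_x]
    return x, logabsdet


x, logabsdet = simplex_with_logabsdet([0.3, -1.2, 0.7])
print(x, sum(x), logabsdet)
```

Both results come out of a single pass over `y`, so nothing is computed twice. For the opposite direction (simplex to unconstrained), the log-determinant is the negative of this value.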