TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Question on simplex bijector implementation #283

Open · Red-Portal opened this issue 11 months ago

Red-Portal commented 11 months ago

Hi,

It appears that torch.probability simply uses softmax for the simplex bijector. Is there a reason our simplex transform is much more complicated? I was also thinking about a GPU-friendly implementation, which appears hard to do with the current implementation.

torfjelde commented 11 months ago

Softmax isn't bijective. The one we have now is (it maps from d to (d-1) dimensions).
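(A tiny illustration of the non-bijectivity, assuming `softmax` from LogExpFunctions: the map is shift-invariant, so distinct inputs collapse to the same simplex point and it cannot be inverted on all of R^d.)

```julia
using LogExpFunctions: softmax

# softmax is invariant to adding a constant to every logit, so distinct
# inputs map to the same simplex point:
y = randn(3)
softmax(y) ≈ softmax(y .+ 5.0)  # true
```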

devmotion commented 11 months ago

See also #51.

Red-Portal commented 11 months ago

Betanalpha does discuss a bijective softmax by arbitrarily setting the endpoint logits. Any experience with this?

devmotion commented 11 months ago

That's supported e.g. in GPLikelihoods (see maybe also the discussion in https://github.com/JuliaGaussianProcesses/GPLikelihoods.jl/issues/55).

Red-Portal commented 11 months ago

Good to know, thanks. Though, back to my original intention: I really wish our simplex bijector could play nicely with GPUs out of the box. Among the non-NF bijectors, the simplex bijector seems like it will be the big challenge in that direction. Do we have any plans for how to pursue this? It does seem to me that the softmax approach would make this much easier.

Red-Portal commented 11 months ago

Actually, never mind. I just wrote a stick-breaking bijector using array operations, based on the NumPyro and TensorFlow implementations. If this were to be added to Bijectors.jl, we'd probably have to add a CUDA array specialization. Let me know how to proceed on this.
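(For concreteness, a minimal sketch — not the code referred to above, and not the Bijectors.jl API — of what a stick-breaking simplex map written with whole-array operations can look like, in the general style of the NumPyro/TensorFlow implementations. The function name and the `LogExpFunctions.logistic` dependency are my own choices, and the `vcat`s may still need tweaks to be fully CuArray-friendly.)

```julia
using LogExpFunctions: logistic

# Map an unconstrained vector y ∈ R^(K-1) to a point on the K-simplex via
# stick-breaking, written with broadcasts and cumprod instead of a loop.
function stickbreaking_forward(y::AbstractVector)
    K = length(y) + 1
    # shift the logits so that y = 0 maps to the uniform simplex point
    z = logistic.(y .- log.(K-1:-1:1))
    # remaining stick length before each break: [1, cumulative ∏(1 - z_j)]
    rem = vcat(one(eltype(z)), cumprod(1 .- z))
    # each component takes a z-fraction of the remaining stick;
    # the last component absorbs whatever is left
    return vcat(z, one(eltype(z))) .* rem
end
```

As a quick sanity check, `stickbreaking_forward(zeros(K - 1))` returns the uniform point `fill(1/K, K)`, and the output always sums to one.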

devmotion commented 11 months ago

On Julia >= 1.9, a CUDA specialization could be put in a package extension (possibly it could even just be an extension against GPUArrays).
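(If it helps, a rough sketch of what such an extension module could look like under the Julia >= 1.9 mechanism. The module/file name and the dispatch target are hypothetical; the corresponding `[weakdeps]`/`[extensions]` entries in Bijectors.jl's Project.toml would be up to the maintainers.)

```julia
# ext/BijectorsGPUArraysExt.jl — hypothetical extension module, loaded
# automatically when both Bijectors and GPUArrays are in the environment.
module BijectorsGPUArraysExt

using Bijectors
using GPUArrays: AbstractGPUArray

# GPU-specific methods would go here, e.g. dispatching the simplex
# transform on AbstractGPUArray inputs (signature is illustrative only):
# function Bijectors.transform(b::Bijectors.SimplexBijector, y::AbstractGPUArray)
#     ...
# end

end # module
```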

Red-Portal commented 11 months ago

I do have the feeling that this will have to wait until the batch operation interface is finalized. @torfjelde Do we have an expectation on when that would be?

sethaxen commented 1 month ago

There are three main ways to use softmax for simplex transforms. One uses parameter expansion to retain bijectivity: f(y) = [softmax(y); logsumexp(y)]. The other two come from the compositional data analysis literature and are called the additive log-ratio, f(y) = softmax(vcat(y, 0)), and the isometric log-ratio, f(y) = softmax(V * y), for a particular choice of semi-orthogonal matrix V. I'm currently testing the performance of each of these versus stick-breaking.
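(A minimal sketch of the three maps described above — not Bijectors.jl code; the helper names and the QR-based construction of V are my own choices, assuming `softmax`/`logsumexp` from LogExpFunctions.)

```julia
using LinearAlgebra
using LogExpFunctions: softmax, logsumexp

# Expanded softmax: R^K -> (K-simplex, R), kept bijective by also
# returning the log-normalizer.
expanded_softmax(y) = (softmax(y), logsumexp(y))

# Additive log-ratio: R^(K-1) -> K-simplex, obtained by pinning the
# last logit to zero.
alr(y) = softmax(vcat(y, zero(eltype(y))))

# Isometric log-ratio: R^(K-1) -> K-simplex via a semi-orthogonal matrix V
# whose columns are orthonormal and orthogonal to ones(K).
function ilr_basis(K)
    # one construction: QR-factorize [ones(K) e_1 … e_(K-1)] and drop the
    # first column of Q, which spans ones(K)
    Q = Matrix(qr(hcat(ones(K), Matrix(1.0I, K, K)[:, 1:K-1])).Q)
    return Q[:, 2:K]
end
ilr(y, V) = softmax(V * y)
```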