TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Make SimplexBijector actually bijective #263

Closed sethaxen closed 1 year ago

sethaxen commented 1 year ago

Similar to #228, the SimplexBijector currently makes transformed distributions improper. A demo from Slack:

julia> using Turing

julia> @model function foo()
           d = Dirichlet(ones(2))
           x ~ filldist(Flat(), length(d))
           Turing.@addlogprob! logpdf(transformed(d), x)
           y = transform(inverse(bijector(d)), x)
           return (; y)
       end;

julia> chns = sample(foo(), NUTS(500, 0.8), MCMCThreads(), 1_000, 4)
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 12.8
Sampling (4 threads) 100%|█████████████████████████████████████████████████████████████| Time: 0:00:06
Chains MCMC chain (1000×14×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 5.58 seconds
Compute duration  = 20.62 seconds
parameters        = x[1], x[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters                     mean                      std                     mcse    ess_bulk    ess_ta ⋯
      Symbol                  Float64                  Float64                  Float64     Float64     Float ⋯

        x[1]                   0.0359                   1.7472                   0.0299   3450.8940   2517.15 ⋯
        x[2]   -1255806938226913.2500   38456907341159336.0000   12941276533029076.0000      9.0031     15.55 ⋯
                                                                                              3 columns omitted

Quantiles
  parameters                      2.5%                     25.0%                     50.0%                    ⋯
      Symbol                   Float64                   Float64                   Float64                  F ⋯

        x[1]                   -3.4263                   -1.0549                    0.0211                    ⋯
        x[2]   -75987690476946768.0000   -23883185857795536.0000   -10795918685407138.0000   2113099877552726 ⋯

julia> yvals = permutedims(stack(first.(generated_quantities(foo(), chns))), (2, 3, 1));

julia> ess(yvals)
2-element Vector{Float64}:
 3450.8940281885793
 3450.8940281885766

julia> dropdims(mean(yvals; dims=(1, 2)); dims=(1, 2))
2-element Vector{Float64}:
 0.5055473060755683
 0.49445269392443175

This PR changes SimplexBijector to transform a K-vector to a (K-1)-vector. Since the proj type entry in SimplexBijector only affected the extra Kth entry of the unconstrained vector, it has been removed. Since the Jacobian is now non-square, triangular return types are no longer used. As a result, the change is marked as breaking.
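
To see why the extra coordinate makes the transformed distribution improper, here is a rough sketch for the K = 2 demo above. It assumes a stick-breaking parameterization with x_1 = logistic(y_1), and that the old Kth unconstrained entry y_2 never influences the simplex point, as the description of proj suggests. The symbols q_old and q_new are introduced here for illustration only and do not appear in the package.

```latex
% K = 2, d = Dirichlet(1, 1), so p(x) = 1 on the simplex.
% Assumed map: x_1 = \sigma(y_1), x_2 = 1 - x_1, with \sigma the logistic function.
\begin{align*}
  q_{\text{old}}(y_1, y_2) &= \sigma(y_1)\,\bigl(1 - \sigma(y_1)\bigr)
    && \text{(constant in } y_2\text{)} \\
  \int_{\mathbb{R}^2} q_{\text{old}}(y_1, y_2)\,\mathrm{d}y_1\,\mathrm{d}y_2
    &= \underbrace{\int_{\mathbb{R}} \sigma(y_1)\bigl(1 - \sigma(y_1)\bigr)\,\mathrm{d}y_1}_{=\,1}
       \cdot \int_{\mathbb{R}} 1\,\mathrm{d}y_2 = \infty \\
  q_{\text{new}}(y_1) &= \sigma(y_1)\,\bigl(1 - \sigma(y_1)\bigr),
  \qquad \int_{\mathbb{R}} q_{\text{new}}(y_1)\,\mathrm{d}y_1 = 1
\end{align*}
```

Under these assumptions q_old is flat in y_2, so the sampler is free to drift along that axis. That is consistent with x[2] diverging in the demo while the constrained y still mixes well.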

sethaxen commented 1 year ago

Currently tests fail due to these lines, which seem to assume inputs and outputs are the same size.

On Slack, @torfjelde confirmed that these should be fixed.

torfjelde commented 1 year ago

Haven't forgotten about this, but the DPPL integration was set back significantly by some other changes we had made. Should be done soon now :+1:

torfjelde commented 1 year ago

Aaaalrighty! Does someone want to give this a look-over? I think tests will pass now, and so it would be nice to get this merged.

sethaxen commented 1 year ago

Thanks @torfjelde for the fixes! All LGTM, but someone else needs to review.

torfjelde commented 1 year ago

> It's a breaking change it seems, so IMO it would be good to include the correct version bump in the PR to avoid accidentally tagging a non-breaking release.

We haven't released #master yet, which has been bumped accordingly :) I'm deferring the release of #master until both this and #271 have gone through.