tpapp / TransformVariables.jl

Transformations to constrained variables from ℝⁿ.

Feature Idea: flat transform #94

Open scheidan opened 2 years ago

scheidan commented 2 years ago

It would be useful if transform had an option so that the result remains a flat vector:

transform(t, x, keep_flat=true)

A good use case is converting MCMC samples into MCMCChains.Chains objects:

import MCMCChains
using TransformVariables

# `t` is assumed to be a transformation with dimension(t) == 5
samp = rand(1000, 5)            # in practice these would come from an MCMC algorithm

# we get an array of named tuples, which is great for defining the model but difficult to convert to `Chains`.
samp_trans1 = mapslices(s -> transform(t, s), samp, dims=2)
MCMCChains.Chains(samp_trans1)              # fails

# with the new argument we would get an array
samp_trans2 = mapslices(s -> transform(t, s, keep_flat=true), samp, dims=2)
MCMCChains.Chains(samp_trans2)              # that would work

This seems related to #13.

tpapp commented 2 years ago

What is the format of samp_trans2 here that you would expect? I am not familiar with MCMCChains.Chains.

scheidan commented 2 years ago

MCMCChains.Chains expects an Array of dimensions iterations × n_parameters (or iterations × n_parameters × n_chains).
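As a rough illustration of that layout, assuming each draw has already been flattened to a plain vector, the array could be assembled with Base functions only. The names `to_matrix` and `flat_draws` are made up for this sketch, and the random numbers merely stand in for real MCMC output:

```julia
# Hypothetical helper: stack equally long flat draws into an
# iterations × n_parameters matrix (one row per draw).
to_matrix(flat_draws) = permutedims(reduce(hcat, flat_draws))

# Stand-in for flattened MCMC output: 1000 draws with 6 values each.
flat_draws = [randn(6) for _ in 1:1000]
chain1 = to_matrix(flat_draws)

# Several chains can be concatenated along the third dimension,
# giving iterations × n_parameters × n_chains.
chain2 = to_matrix([randn(6) for _ in 1:1000])
both = cat(chain1, chain2; dims=3)
```

Here `reduce(hcat, ...)` puts each draw in a column, so `permutedims` is needed to get one row per iteration.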

Having a flat transform would make the construction of such an Array quite easy. We would need to be careful with the length:

t = as((a = asℝ,
        b = as(Vector, as(Real, 0, 1), 2),
        c = UnitVector(3)))

x = randn(dimension(t))  # length(x) == 5
transform(t, x)  # -> NamedTuple
transform(t, x, keep_flat=true)  # -> vector of length 6 != dimension(t)

tpapp commented 2 years ago

Thanks, I get it. It should be relatively easy to flatten transformed values:

flatten(x::Real) = [x]
flatten(x::AbstractArray) = vec(x)
flatten(x::Tuple) = mapreduce(flatten, vcat, x)
flatten(x::NamedTuple) = mapreduce(flatten, vcat, values(x))

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))

flatten(z)

This can deal with everything TransformVariables can dish out at the moment. (The code above necessarily allocates and is quite suboptimal; in the ideal case this would be done with views, as in https://github.com/JuliaArrays/StackViews.jl.)

Or would you prefer transforming directly to a flat vector for efficiency? I will keep this in mind for the next refactoring (which is coming up soon).
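On the allocation point, one possible direction is to write into a preallocated buffer instead of concatenating. This is only a sketch and not something TransformVariables provides: the names `flat_length` and `flatten!` are made up here, and arrays are assumed to contain plain reals.

```julia
# Number of scalars in a (possibly nested) transformed value.
flat_length(x::Real) = 1
flat_length(x::AbstractArray) = length(x)   # assumes an array of reals
flat_length(x::Tuple) = sum(flat_length, x; init = 0)
flat_length(x::NamedTuple) = flat_length(values(x))

# Write the scalars of `x` into `out` starting at index `i`
# and return the next free index.
function flatten!(out::AbstractVector, x::Real, i::Int = 1)
    out[i] = x
    return i + 1
end
function flatten!(out::AbstractVector, x::AbstractArray, i::Int = 1)
    for v in x              # column-major order, same as vec(x)
        out[i] = v
        i += 1
    end
    return i
end
function flatten!(out::AbstractVector, x::Tuple, i::Int = 1)
    for v in x
        i = flatten!(out, v, i)
    end
    return i
end
flatten!(out::AbstractVector, x::NamedTuple, i::Int = 1) = flatten!(out, values(x), i)

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
out = Vector{Float64}(undef, flat_length(z))
flatten!(out, z)
```

The same `flatten!` could fill one row of a preallocated iterations × n_parameters matrix per draw by passing `view(mat, k, :)` as `out`.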

tpapp commented 2 years ago

Also, an ideal API would give column names, such as [:a, :b_1, :b_2, :c_d, :c_e] or similar.
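A minimal sketch of how such names could be derived from a transformed value, in the same order in which the values would be flattened. The function `flat_names` is hypothetical, and only reals, arrays of reals, and (nested) NamedTuples are handled:

```julia
# Hypothetical helper: one Symbol per scalar, parts joined with underscores.
flat_names(x::Real, prefix) = [Symbol(prefix)]
flat_names(x::AbstractArray, prefix) = [Symbol(prefix, "_", i) for i in 1:length(x)]
function flat_names(x::NamedTuple, prefix = "")
    names = Symbol[]
    for (k, v) in pairs(x)
        # Nested fields are prefixed with the name of their parent.
        p = isempty(prefix) ? string(k) : string(prefix, "_", k)
        append!(names, flat_names(v, p))
    end
    return names
end

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
names = flat_names(z)
```

Arrays are numbered by linear index here, so a matrix would get `_1`, `_2`, ... in column-major order. Unnamed tuples would need their own naming rule.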

scheidan commented 2 years ago

Getting meaningful names would be very helpful!

MCMCChains.jl has some support for names with brackets for variables that come from arrays, e.g. "x[1,1]", "x[1,2]"; see https://beta.turing.ml/MCMCChains.jl/stable/getting-started/#Groups-of-parameters
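A small sketch of generating names in that bracket style, assuming a single-level NamedTuple whose fields are reals or arrays of reals. The function `bracket_names` is made up for illustration and is not part of either package:

```julia
# Hypothetical helper: names in the bracket style, e.g. "x[1,2]".
bracket_names(x::Real, name) = [string(name)]
function bracket_names(x::AbstractArray, name)
    # CartesianIndices iterates in column-major order, matching vec(x).
    return vec([string(name, "[", join(Tuple(I), ","), "]") for I in CartesianIndices(x)])
end
function bracket_names(x::NamedTuple)
    names = String[]
    for (k, v) in pairs(x)
        append!(names, bracket_names(v, k))
    end
    return names
end

z = (a = 1.0, b = [2.0, 3.0], x = [1.0 2.0; 3.0 4.0])
names = bracket_names(z)
```

Whether these strings can be handed to MCMCChains directly as parameter names is best checked against the linked documentation.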