arviz-devs / InferenceObjects.jl

Storage for results of Bayesian inference
https://julia.arviz.org/InferenceObjects
MIT License

Utility for unflattening Datasets #27

Open sethaxen opened 2 years ago

sethaxen commented 2 years ago

The natural way to represent a draw from a posterior distribution is as a NamedTuple whose keys are the parameter names and whose values are the corresponding parameter values, which can be scalars, arrays, or arbitrary Julia objects. All draws for a chain then form a vector of such NamedTuples, and multiple chains form a vector of these vectors. When we convert to InferenceData, we "flatten" the draws until we get numeric arrays. Each element of such an array is a marginal draw, which is useful for plotting and diagnostics.
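To make the flattening step concrete, here is a minimal sketch in plain Julia. It is not the InferenceObjects conversion code: the names make_draw and flatten_param are hypothetical, the draws are made-up toy values, and the sketch assumes every draw of a parameter has the same shape. It follows the (parameter dims..., draw, chain) layout used in the examples in this issue.

```julia
# Toy input: 2 chains of 3 draws each, every draw a NamedTuple (made-up values).
make_draw(c, d) = (mu = 10.0 * c + d, theta = [100.0 * c + d, 200.0 * c + d])
chains = [[make_draw(c, d) for d in 1:3] for c in 1:2]

# Hypothetical helper: collect one parameter into an array with dimensions
# (parameter dims..., draw, chain). Assumes every draw has the same shape.
function flatten_param(chains, name::Symbol)
    ndraws, nchains = length(first(chains)), length(chains)
    first_val = getproperty(first(first(chains)), name)
    out = Array{eltype(first_val)}(undef, size(first_val)..., ndraws, nchains)
    lead = ntuple(_ -> Colon(), ndims(first_val))  # one colon per parameter dim
    for c in 1:nchains, d in 1:ndraws
        out[lead..., d, c] = getproperty(chains[c][d], name)
    end
    return out
end

mu = flatten_param(chains, :mu)        # scalar draws become a draw × chain matrix
theta = flatten_param(chains, :theta)  # vector draws gain a leading dimension
```

Each element of mu and theta is then a single marginal draw, which is the form that plotting and diagnostics work with.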

Sometimes, though, users need the unflattened draws. For example, when interacting with a PPL, one often needs draws in the format produced by that PPL, which in general will not look like a Dataset. In #11 we discuss ideas for not flattening at all. A simpler alternative is to provide utility functions for "unflattening". Here's an example of such a function:

julia> using DimensionalData, InferenceObjects

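julia> function unflatten(f, v, keep_dims=(:chain, :draw))
           # dimensions to fold into each draw (everything except keep_dims)
           dims = Dimensions.otherdims(v, keep_dims)
           # nothing to fold: the draws are already scalars
           isempty(dims) && return v
           # the kept dimensions, in the order they appear in v
           keep_dims_actual = Dimensions.otherdims(v, dims)
           dimnums = Dimensions.dimnum(v, dims)
           # apply f to each slice and wrap the result in a 1-element vector so
           # the folded dimensions become singletons that dropdims can remove
           data_new = dropdims(mapslices(Base.vect ∘ f, parent(v); dims=dimnums); dims=dimnums)
           return DimArray(data_new, keep_dims_actual)
       end;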

By passing f=identity, we can handle the case where draws are scalars or arrays of scalars:

julia> x = convert_to_dataset((; x=randn(2, 3, 8, 4)); dims=(x=[:a, :b],)).x
2×3×8×4 DimArray{Float64,4} x with dimensions: 
  Dim{:a} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points,
  Dim{:b} Sampled{Int64} Base.OneTo(3) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1, 1]
     1         2         3
 1   0.844121  1.79069  -0.349435
 2  -0.435955  2.21937   0.102086
[and 31 more slices...]

julia> x_unflat = unflatten(identity, x)
8×4 DimArray{Matrix{Float64},2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
    …  4
 1      [-0.710667 2.17712 1.25004; 1.15662 0.138343 0.511868]
 2      [-1.59593 -0.847627 0.185637; 1.62011 0.733101 -0.82679]
 3      [1.6434 0.188635 0.434926; -2.70953 0.223494 -1.00055]
 4      [-0.753704 -2.25251 0.32903; 1.97774 -0.744595 1.0287]
 5  …   [0.837521 -0.252849 0.0989726; -1.10382 0.511166 0.566629]
 6      [-1.58429 -0.164573 1.83263; 0.875992 -0.174146 -1.10488]
 7      [-2.21422 -0.398891 -1.26135; 1.27395 -0.150042 0.243492]
 8      [0.789781 0.052268 -1.51552; 0.5554 1.08581 -1.16574]

julia> x_unflat[1]
2×3 Matrix{Float64}:
  0.844121  1.79069  -0.349435
 -0.435955  2.21937   0.102086

Other choices of f let us handle cases where draws are not array types. For example, here's how we might unflatten a real array representing complex draws:

julia> z = convert_to_dataset((; z=randn(2, 8, 4)); dims=(z=[:reim],), coords=(reim=[:re, :im],)).z
2×8×4 DimArray{Float64,3} z with dimensions: 
  Dim{:reim} Categorical{Symbol} Symbol[re, im] ReverseOrdered,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1]
        1           2          3          4          5          6          7         8
  :re   0.149895   -0.758902  -0.162169  -1.58568   -1.9113    -0.873895  -1.15336  -0.723117
  :im  -0.0615223  -0.191197   0.552402   0.754498  -0.139014   0.496133   1.69164  -1.05489
[and 3 more slices...]

julia> z_unflat = unflatten(Base.splat(complex), z)
8×4 DimArray{ComplexF64,2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
             1                      2                      3                       4
 1   0.149895-0.0615223im   -1.49337+0.0310133im  -0.186236-0.632437im     0.122908-1.8747im
 2  -0.758902-0.191197im   0.0847507+1.8477im      0.699646+0.0940246im   -0.700787+0.589689im
 3  -0.162169+0.552402im   -0.426661-0.215763im     1.24455-0.30482im       0.87671-0.0396714im
 4   -1.58568+0.754498im    -1.08887-0.0911398im   -1.18796+0.0439568im    0.583836+0.226613im
 5    -1.9113-0.139014im    -1.11748+0.521976im   -0.453853-0.668656im     -1.40155+0.216688im
 6  -0.873895+0.496133im    0.471934+0.508555im     -1.1003-0.844055im       2.6073-0.25573im
 7   -1.15336+1.69164im     0.107038+0.070659im    -2.15358-1.19693im    -0.0646238-0.749879im
 8  -0.723117-1.05489im       1.0455-0.601896im   -0.931837+0.621233im     0.789712+0.442579im
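For readers without DimensionalData at hand, the same idea can be sketched on plain arrays. This is only an illustration under stated assumptions: the name unflatten_plain is hypothetical, and it assumes the last two array dimensions are draw and chain, as in the examples above. Unlike unflatten, it slices by position rather than by dimension name.

```julia
# Hypothetical plain-array analogue of `unflatten`.
# Assumes the trailing two dimensions of `x` are (draw, chain).
function unflatten_plain(f, x::AbstractArray)
    nd = ndims(x)
    nd <= 2 && return x                  # draws are already scalars
    ndraws, nchains = size(x, nd - 1), size(x, nd)
    lead = ntuple(_ -> Colon(), nd - 2)  # one colon per parameter dimension
    # apply f to the slice belonging to each (draw, chain) pair
    return [f(x[lead..., d, c]) for d in 1:ndraws, c in 1:nchains]
end

# Matrix-valued draws: each element of `u` is a 2×3 matrix.
x = reshape(collect(1.0:48.0), 2, 3, 4, 2)
u = unflatten_plain(identity, x)

# Complex draws stored as (re, im) pairs along the first dimension.
zr = reshape(collect(1.0:16.0), 2, 4, 2)
z = unflatten_plain(v -> complex(v...), zr)
```

The keep_dims argument of unflatten is the more flexible design, since it does not depend on where the draw and chain dimensions sit in the array.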

By applying this approach to all parameters in a Dataset, we can unflatten everything:

julia> using ArviZExampleData

julia> idata = load_example_data("centered_eight");

julia> post = idata.posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> post_new = Dataset(map(v -> unflatten(identity, v), NamedTuple(post)))
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Vector{Float64} dims: Dim{:draw}, Dim{:chain} (500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

julia> post_new[1]
(mu = 7.871796366146925, theta = [12.320685578094814, 9.905366892588605, 14.9516154956564, 11.011484941973162, 5.5796015919074735, 16.901795293711004, 13.198059333176934, 15.06136583596694], tau = 4.725740062893666)
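Once every parameter is indexed only by draw and chain, getting back to the vector-of-chains layout described at the top of this issue is a short step. The sketch below uses plain Julia with toy values; to_chains is a hypothetical name, and it assumes each parameter is stored as a draw × chain matrix.

```julia
# Hypothetical helper: turn a NamedTuple of draw × chain matrices into a
# vector of chains, where each chain is a vector of NamedTuple draws.
function to_chains(params::NamedTuple)
    ndraws, nchains = size(first(params))
    # map over a NamedTuple keeps its keys, so each draw is again a NamedTuple
    return [[map(p -> p[d, c], params) for d in 1:ndraws] for c in 1:nchains]
end

# Toy unflattened parameters: 3 draws × 2 chains (made-up values).
params = (
    mu = [1.0 2.0; 3.0 4.0; 5.0 6.0],
    theta = [[d, c] for d in 1:3, c in 1:2],
)
draws_by_chain = to_chains(params)
draws_by_chain[2][3]  # third draw of the second chain, as a NamedTuple
```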

I propose we add something like this utility to the API to make it easier to use InferenceObjects with PPLs.