arviz-devs / InferenceObjects.jl

Storage for results of Bayesian inference
https://julia.arviz.org/InferenceObjects
MIT License

Better supporting draws that are arbitrary Julia types #11

Open sethaxen opened 2 years ago

sethaxen commented 2 years ago

This issue continues discussion starting at https://github.com/arviz-devs/InferenceObjects.jl/issues/8#issuecomment-1223007423.

Some Julia PPLs can return draws as arbitrary Julia types. Here's an example with Soss:

```julia
julia> using Soss

julia> struct Foo
           x
           y
           tag
       end

julia> mod1 = @model n begin
           x ~ Normal()
           y ~ Normal() |> iid(n)
           return Foo(x, y, :discrete)
       end;

julia> mod2 = @model n begin
           t ~ mod1(n)
           z ~ Normal(μ=t.x)
       end;

julia> rand(mod2(3))
(t = Foo(-1.8692299945695137, [-0.9020275201942468, -0.9380392196474631, -0.041490841817294566], :discrete), z = -3.0169797376605327)
```

Currently such types can be stored in InferenceData:

```julia
julia> using InferenceObjects

julia> data = (; a = [rand(mod2(3)) for _ in 1:4, _ in 1:100]);

julia> idata = InferenceData(posterior=namedtuple_to_dataset(data))
InferenceData with groups:
  > posterior

julia> idata.posterior
Dataset with dimensions:
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points
and 1 layer:
  :a NamedTuple{(:t, :z), Tuple{Foo, Float64}} dims: Dim{:chain}, Dim{:draw} (4×100)

with metadata OrderedCollections.OrderedDict{Symbol, Any} with 1 entry:
  :created_at => "2022-08-25T11:15:48.582"

julia> idata.posterior.a
4×100 DimArray{NamedTuple{(:t, :z), Tuple{Foo, Float64}},2} a with dimensions:
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points
     …  100
 1      (t = Foo(-0.332783, [-0.271914, -1.19732, 0.239832], :discrete), z = -0.473026)
 2      (t = Foo(-1.03842, [-0.148646, -0.102317, 0.242476], :discrete), z = -1.53602)
 3      (t = Foo(0.902033, [-0.798571, 0.173176, 0.533269], :discrete), z = 1.45302)
 4      (t = Foo(-0.98641, [1.89491, -0.674791, -0.203847], :discrete), z = -1.44689)
```

So InferenceData can be used for this storage, but it's not very useful, for several reasons:

Here's an example of what the Tables interface would produce:

```julia
julia> using DataFrames

julia> DataFrame(idata.posterior)
400×3 DataFrame
 Row │ chain  draw   a
     │ Int64  Int64  NamedTup…
─────┼─────────────────────────────────────────────────
   1 │     1      1  (t = Foo(0.235607, [1.08405, 0.9…
   2 │     2      1  (t = Foo(-1.24972, [-1.89301, 0.…
   3 │     3      1  (t = Foo(1.0526, [-0.179664, -0.…
   4 │     4      1  (t = Foo(0.793393, [0.558985, 0.…
  ⋮  │   ⋮      ⋮                    ⋮
 398 │     2    100  (t = Foo(-1.03842, [-0.148646, -…
 399 │     3    100  (t = Foo(0.902033, [-0.798571, 0…
 400 │     4    100  (t = Foo(-0.98641, [1.89491, -0.…
                                       393 rows omitted
```

So plotting packages that use the Tables interface, like AlgebraOfGraphics and StatsPlots, are not terribly useful here without lots of additional code.
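Much of that additional code is reshaping: once a marginal is a plain numeric chain×draw array, it only has to be laid out in long format, with `chain` varying fastest as in the table above. A minimal sketch in Base Julia, assuming the marginal (for example the `z` component of each draw) has already been extracted into a matrix; `long_columns` is a hypothetical helper, not part of InferenceObjects.jl:

```julia
# Hypothetical helper (not part of InferenceObjects.jl): build long-format columns
# for one chain×draw marginal, with `chain` varying fastest as in the table above.
function long_columns(name::Symbol, m::AbstractMatrix)
    nchains, ndraws = size(m)
    chain = repeat(1:nchains; outer=ndraws)   # 1, 2, 3, 4, 1, 2, ...
    draw = repeat(1:ndraws; inner=nchains)    # 1, 1, 1, 1, 2, 2, ...
    # `vec` is column-major, so it matches the (chain, draw) order built above.
    # A NamedTuple of equal-length vectors is a valid Tables.jl column table.
    return NamedTuple{(:chain, :draw, name)}((chain, draw, vec(m)))
end
```

Because the result is a column table, it could be passed to `DataFrame` or to a Tables-based plotting package; the hard part this issue is about is getting from nested draws to such numeric matrices in the first place.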

There are several ways we might approach this:

  1. Do nothing. Users are free to use arbitrary types with InferenceData, and they are expected to turn their types into whatever marginals they care about when they want to use the downstream functions we discussed above. This is the current state.
  2. Require all converters to flatten to the marginals. The converter might encode some of the structure into the Dataset, e.g. the above example might be converted to a Dataset with variable names a.t.x, a.t.y, a.t.tag, and a.z. If we go this route, InferenceData would be a secondary data type used only for some analyses, but not a possible default for such PPLs, since it loses some of the structure in the initial draws.
  3. Define an interface for computing a "marginal representation" of a variable, dataset, or whole InferenceData. Users would call it to convert a non-flattened InferenceData into a flattened one, which would also let them provide named dimensions. E.g. such a function would map the above posterior to something like:
```julia
julia> using Compat

julia> a = idata.posterior.a;

julia> d = (;
           var"a.t.x"=map(x -> x.t.x, a),
           var"a.t.y"=permutedims(Compat.stack(map(x -> x.t.y, a)), (2, 3, 1)),
           var"a.t.tag"=map(x -> x.t.tag, a),
           var"a.z"=map(x -> x.z, a),
       );

julia> post_new = namedtuple_to_dataset(d)
Dataset with dimensions:
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points,
  Dim{:a.t.y_dim_1} Sampled Base.OneTo(3) ForwardOrdered Regular Points
and 4 layers:
  :a.t.x   Float64 dims: Dim{:chain}, Dim{:draw} (4×100)
  :a.t.y   Float64 dims: Dim{:chain}, Dim{:draw}, Dim{:a.t.y_dim_1} (4×100×3)
  :a.t.tag Symbol dims: Dim{:chain}, Dim{:draw} (4×100)
  :a.z     Float64 dims: Dim{:chain}, Dim{:draw} (4×100)

with metadata OrderedCollections.OrderedDict{Symbol, Any} with 1 entry:
  :created_at => "2022-08-25T12:08:52.898"

julia> DataFrame(post_new)
1200×7 DataFrame
  Row │ chain  draw   a.t.y_dim_1  a.t.x      a.t.y      a.t.tag   a.z
      │ Int64  Int64  Int64        Float64    Float64    Symbol    Float64
──────┼──────────────────────────────────────────────────────────────────────
    1 │     1      1            1   0.235607   1.08405   discrete   1.20053
    2 │     2      1            1  -1.24972   -1.89301   discrete   0.0482549
    3 │     3      1            1   1.0526    -0.179664  discrete   4.70509
    4 │     4      1            1   0.793393   0.558985  discrete   0.919428
    5 │     1      2            1  -1.53217    1.41717   discrete  -1.89603
  ⋮   │   ⋮      ⋮         ⋮          ⋮          ⋮         ⋮          ⋮
 1196 │     4     99            3  -1.65774    0.689958  discrete  -1.00366
 1197 │     1    100            3  -0.332783   0.239832  discrete  -0.473026
 1198 │     2    100            3  -1.03842    0.242476  discrete  -1.53602
 1199 │     3    100            3   0.902033   0.533269  discrete   1.45302
 1200 │     4    100            3  -0.98641   -0.203847  discrete  -1.44689
                                                            1190 rows omitted
```

The easiest way I can think of to provide such a default is to recurse through all Julia types and allocate new arrays as done above, but there may be other options using custom Julia arrays. @oschulz, @cscherrer, would the types you have been suggesting allow for this?

  4. Develop a completely new data structure that allows arbitrarily deep nesting of data types and assignment of dimensions at any level, but also implements both a no-copy marginal view that flattens everything and a no-copy tabular view that further concatenates with useful column names. I don't immediately see how this could be done with existing dimensional data types, so it could be as complicated as developing yet another dimensional data package.
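To make the difference between the eager and lazy approaches concrete, here is a minimal Base-Julia sketch of both: the recursive default suggested under the "marginal representation" option (walk each draw's fields and allocate one new array per marginal, named like `a.t.x`), and a no-copy view in the spirit of the new-data-structure option. All function and type names are hypothetical, it assumes every draw shares the structure of the first one, and it treats numbers, symbols, strings, and arrays as leaves. It does not attach dimensions or build a Dataset.

```julia
# Hypothetical sketch: `marginal_paths`, `getpath`, `stack_marginal`, `flatten_draws`,
# and `MarginalView` are illustrative names, not part of InferenceObjects.jl.

struct Foo  # same shape as the struct in the Soss example
    x
    y
    tag
end

# Values treated as leaves (marginals) instead of recursing into their fields.
isleaf(v) = v isa Union{Number,Symbol,AbstractString,AbstractArray}

# Recursively collect (dotted name, field path) pairs for a single draw.
function marginal_paths(v, prefix::String, path::Tuple=())
    isleaf(v) && return Tuple{String,Tuple}[(prefix, path)]
    out = Tuple{String,Tuple}[]
    for f in fieldnames(typeof(v))  # works for NamedTuples and structs alike
        append!(out, marginal_paths(getfield(v, f), string(prefix, ".", f), (path..., f)))
    end
    return out
end

# Follow a field path, e.g. getpath(draw, (:t, :x)) == draw.t.x
getpath(v, path::Tuple) = foldl(getfield, path; init=v)

# Array-valued leaves get their inner dimensions appended after (chain, draw),
# like `permutedims(stack(...), (2, 3, 1))` in the example above.
function stack_marginal(m::AbstractArray)
    eltype(m) <: AbstractArray || return m
    inner = size(first(m))  # assumes every draw has this shape
    out = similar(first(m), (size(m)..., inner...))
    for I in CartesianIndices(m)
        out[Tuple(I)..., ntuple(_ -> Colon(), length(inner))...] = m[I]
    end
    return out
end

# Eager option: allocate one new array per marginal.
function flatten_draws(varname::String, draws::AbstractArray)
    paths = marginal_paths(first(draws), varname)
    names = Tuple(Symbol(n) for (n, _) in paths)
    cols = Tuple(stack_marginal(map(d -> getpath(d, p), draws)) for (_, p) in paths)
    return NamedTuple{names}(cols)
end

# Lazy (no-copy) option: an array that reads one marginal out of the parent on access.
struct MarginalView{T,N,A<:AbstractArray,P<:Tuple} <: AbstractArray{T,N}
    parent::A
    path::P
end
function MarginalView(parent::AbstractArray{<:Any,N}, path::Tuple) where {N}
    T = typeof(getpath(first(parent), path))
    return MarginalView{T,N,typeof(parent),typeof(path)}(parent, path)
end
Base.size(v::MarginalView) = size(v.parent)
Base.getindex(v::MarginalView{T,N}, I::Vararg{Int,N}) where {T,N} = getpath(v.parent[I...], v.path)
```

With `draws` a chain×draw array like `idata.posterior.a`, `flatten_draws("a", draws)` returns a NamedTuple with keys `a.t.x`, `a.t.y`, `a.t.tag`, and `a.z` that could be handed to `namedtuple_to_dataset`, while `MarginalView(draws, (:t, :x))` exposes the same marginal without copying. A view over a non-scalar leaf such as `t.y` would still yield arrays of vectors, which is one reason a no-copy flattened view is harder than the eager version.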

Off the top of my head, a few additional criteria for the solution:

  1. The InferenceData type and its basic functionality must be kept in a lightweight package and be as generic as possible. It's not even ideal that we depend on DimensionalData, but we do. If we require a complicated solution with lots of dependencies, it should live in its own package, which PPLs or packages with PPL-specific converters can then depend on.
  2. We can allow type piracy for packages within this organization if necessary, but that's it.
  3. The solution should ideally not require the average user or PPL developer to implement some API for their custom types, i.e. there should be sensible defaults.
  4. While we're focusing on increasing usability within Julia, we cannot sacrifice serialization to data structures like NetCDF for archiving and interop with other languages.

Since the others tagged in this issue have thought a lot more about this than I have, I'd appreciate any input or suggestions. cc also @ParadaCarleton

sethaxen commented 2 years ago

@femtomc I wonder if you have input on this as well, since I think Gen traces also can be nested and contain arbitrary Julia types.

oschulz commented 2 years ago

Something that may be helpful in this context: @cscherrer and I have discussed building flatten/unflatten transformations on top of the new transport API in MeasureBase.jl. This would allow transforms to/from flat vectors to be generated automatically as long as a prior measure is available (the measure would provide the required structural information).
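To illustrate the kind of structural information a prior would supply, here is a minimal flatten/unflatten pair in Base Julia. It is not built on MeasureBase.jl and does not use its transport API; as an assumption, a template value (for example a single prior draw) stands in for the prior measure by providing field names and shapes. `flatten_nt` and `unflatten_nt` are hypothetical names, and only real scalars and real arrays are handled.

```julia
# Hypothetical sketch: `flatten_nt`/`unflatten_nt` are illustrative names, not the
# MeasureBase.jl transport API. A template value (e.g. a single prior draw) stands in
# for the structural information a prior measure would supply.

_flat(v::Real) = [Float64(v)]
_flat(v::AbstractArray{<:Real}) = vec(Float64.(v))

# Concatenate all fields, in field order, into one Vector{Float64}.
flatten_nt(x::NamedTuple) = reduce(vcat, map(_flat, values(x)); init=Float64[])

_len(::Real) = 1
_len(t::AbstractArray) = length(t)
_rebuild(::Real, chunk) = only(chunk)
_rebuild(t::AbstractArray, chunk) = reshape(chunk, size(t))

# Split a flat vector back into the fields and shapes of `template`.
function unflatten_nt(template::NamedTuple, flat::AbstractVector{<:Real})
    vals = Any[]
    offset = 0
    for t in values(template)
        n = _len(t)
        push!(vals, _rebuild(t, flat[offset+1:offset+n]))
        offset += n
    end
    offset == length(flat) || throw(DimensionMismatch("flat vector length does not match template"))
    return NamedTuple{keys(template)}(Tuple(vals))
end
```

For example, with `template = (a = zeros(2), b = 0.0)`, the value `(a = [1.2, 2.3], b = 4.2)` flattens to `[1.2, 2.3, 4.2]`, and `unflatten_nt(template, [1.2, 2.3, 4.2])` recovers it. Without the template, the flat vector alone would not say where `a` ends or what shape it has.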

oschulz commented 2 years ago

Another thing that may be interesting in this context: in BAT.jl we've recently added the ability to marginalize/flatten structures to "flat" NamedTuples using unicode. This is currently limited to non-nested input, but the result looks like this: a value (a = [1.2, 2.3], b = 4.2) can be turned into (a⌞1⌟ = 1.2, a⌞2⌟ = 2.3, b = 4.2). We use a few other unicode characters too, so we can preserve range selection during marginalization and still have valid unicode field names like (d⌞1ː2⌟ = ...). We introduced it to support value selection for plotting, but we're planning to extend it and make it more directly accessible. Maybe such a "flatten-nested-names-and-ranges-to-unicode" scheme could be useful for arviz as well?
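A minimal sketch of the element-naming part of that scheme, written only from the example above and not taken from BAT.jl (whose implementation may differ). `flatten_unicode` is a hypothetical name; like the current BAT.jl feature it handles only non-nested input, and it does not implement the range-selection names such as `d⌞1ː2⌟`.

```julia
# Hypothetical sketch of the naming scheme described above, not BAT.jl's implementation.
# Handles only non-nested input: vector-valued fields are split into one scalar field
# per element; all other fields pass through unchanged.
function flatten_unicode(x::NamedTuple)
    names = Symbol[]
    vals = Any[]
    for (k, v) in pairs(x)
        if v isa AbstractVector
            for i in eachindex(v)
                push!(names, Symbol(k, "⌞", i, "⌟"))  # e.g. :a, 1 -> Symbol("a⌞1⌟")
                push!(vals, v[i])
            end
        else
            push!(names, k)
            push!(vals, v)
        end
    end
    return NamedTuple{Tuple(names)}(Tuple(vals))
end
```

Applied to `(a = [1.2, 2.3], b = 4.2)`, this produces a NamedTuple with keys `a⌞1⌟`, `a⌞2⌟`, and `b`. Compared with dotted names like `a.t.y`, the bracket characters keep index information in the name while the result remains a plain NamedTuple.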

cscherrer commented 2 years ago
Some more transform discussion:

I think it's also worth noting that transforms/transports can be data-dependent. I don't know how Turing does things, but TransformVariables expects the transform to be static. For Tilde, I think we can make things a lot more flexible: instead of running the model once to determine the transport, the transform can itself be defined in terms of a model run. For inference, the simple approach would then run the model twice: once to get the transformed value, and again to get the log-density. But that's inefficient, so I think we'll compute the log-density along the way. If there's a Vector of samples available, we'll write into it as we go.

> All downstream diagnostics, statistics, plots, and serialization to NetCDF/Zarr will require access to marginals, so we need flat multidimensional arrays, often with numeric types.

@sethaxen I think of TupleVectors as making it easy to get to marginals, so maybe I don't understand what you mean by "marginals". Can you give more detail on this?