arviz-devs / ArviZ.jl

Exploratory analysis of Bayesian models with Julia
https://julia.arviz.org

Toward a completely PPL-agnostic Bayesian workflow #154

Open sethaxen opened 2 years ago

sethaxen commented 2 years ago

This first post is a stream-of-consciousness dump of some ideas I've been tossing around in my head for a few months now. I'll edit it for clarity as needed.

The goal

To construct a small ecosystem of completely PPL-agnostic functions for diagnostics, statistics, and plotting in a typical Bayesian workflow.

The problem

Some solutions

I am becoming more convinced that the latter is what we need.

An InferenceData interface?

There are already a few relevant interfaces in the ecosystem that we should mention.

What we need is something a little different from all of these, but likely with some overlap that will let us reuse some of these interfaces. An arbitrary diagnostic/statistic/plotting function may need access to one or more of the following (non-exhaustive list):

The idea is that all or part of the above interface can be implemented for objects returned by a PPL. E.g., an MCMCChains.Chains object could implement the parts corresponding to raw samples, while a Soss.Model or a DynamicPPL.Model would implement the part corresponding to the log density function, and then these objects could be passed directly to certain functions already. But in many cases one would like features from the samples, the model object, and the data together, so I think what we're looking for is some super-object that ties these objects together to represent the current state of inference and implements the entire interface. Such an object could then be passed to any generic function, which could even update the object to store generated quantities like prior samples and avoid re-drawing them.
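To make the super-object idea concrete, here is a minimal sketch. All of these names (`AbstractInferenceState`, `posterior_draws`, etc.) are hypothetical, invented purely for illustration; they are not an existing API.

```julia
# Hypothetical interface sketch; none of these names exist yet. They only
# illustrate how PPL objects could each implement part of one shared interface.
abstract type AbstractInferenceState end

# Interface functions a backend might implement for its own types:
function posterior_draws end   # raw samples with chain/draw dimensions
function log_density end       # log density function of the conditioned model
function observed_data end     # the data the model was conditioned on

# A minimal concrete state tying the pieces together:
struct InferenceState{D,M,O} <: AbstractInferenceState
    draws::D
    model::M
    data::O
end
posterior_draws(state::InferenceState) = state.draws
observed_data(state::InferenceState) = state.data
```

A Chains object would only implement the draws-related functions, a model object only the log-density-related ones, while the combined state could implement everything.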

See e.g. https://github.com/TuringLang/AbstractPPL.jl/issues/27 for an example of how one might implement simulation-based calibration (SBC) with just generic interface functions.
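For intuition, the statistic at the core of SBC is small; the hard part, which the linked issue covers, is the generic interface for conditioning and re-sampling the model. A hedged sketch of just the rank computation (`sbc_rank` is an illustrative name, not an existing function):

```julia
# For a prior draw θ and posterior draws θs (conditioned on data simulated
# from θ), the rank is the number of posterior draws below θ. If inference
# is calibrated, these ranks are uniform over 0:length(θs).
sbc_rank(θ::Real, θs::AbstractVector{<:Real}) = count(<(θ), θs)
```

E.g. `sbc_rank(0.5, [0.1, 0.4, 0.6, 0.9])` returns `2`. Everything else in SBC, simulating data from the prior predictive and re-running inference, is exactly what needs the PPL-agnostic interface.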

Some considerations:

ParadaCarleton commented 2 years ago

What would such an interface entail?

I ask because, apart from MCMCChains, I don't think there's any other package that you'd need an interface to use, and it's already easy enough to convert from MCMCChains to ArviZ.jl's InferenceData. I'm actually more interested in whether there's a real need to keep MCMCChains separate from ArviZ.jl, or whether in theory you could eventually merge them, the same way several Python PPLs switched from their own representations to using ArviZ's.

Maybe just starting with #128 would be enough for now?

sethaxen commented 2 years ago

> What would such an interface entail?
>
> I ask because, apart from MCMCChains, I don't think there's any other package that you'd need an interface to use, and it's already easy enough to convert from MCMCChains to ArviZ.jl's InferenceData. I'm actually more interested in whether there's a real need to keep MCMCChains separate from ArviZ.jl, or whether in theory you could eventually merge them, the same way several Python PPLs switched from their own representations to using ArviZ's.

Storage is only part of the picture. There are diagnostics like SBC that require access to some model object, the ability to condition it, and the ability to sample from it with an inference method (see https://github.com/TuringLang/AbstractPPL.jl/issues/27 for a description of the minimal functionality). Doing so in a PPL-agnostic, user-friendly way requires an interface. So in short, no, MCMCChains does not do everything we need. AbstractPPL comes closer.

But even for storage, MCMCChains doesn't give us everything we need (yet). The main issue is that Chains linearizes everything. So if one of your parameters is a 3D array, you as the user can't just get that 3D array (with 2 other dimensions for iteration and chain) from Chains without some boilerplate (see https://github.com/TuringLang/MCMCChains.jl/issues/305). ArviZ on the other hand prioritizes parameters that have arbitrary numbers of dimensions, as well as named dims/coords; the latter are not only useful for analysis of large models but also for making more interpretable plots.
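As an illustration of the boilerplate (the names here are made up; this is not the Chains API): suppose a backend stores draws of a parameter whose natural shape is (2, 2, 2) as a flat (draws × chains × 8) numeric array. Recovering the structured view is a reshape that the user currently has to write themselves:

```julia
# Illustrative only: recovering a structured parameter from a linearized store.
ndraws, nchains = 100, 4
pshape = (2, 2, 2)                              # the parameter's own shape
vals = randn(ndraws, nchains, prod(pshape))     # the flat storage
a = reshape(vals, ndraws, nchains, pshape...)   # draws × chains × 2 × 2 × 2
```

This also relies on knowing the column order and the original shape, which is exactly the metadata that gets lost in a fully linearized representation.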

But even ArviZ falls short on storage, because it's still biased toward arrays. In principle, however, one can have parameters be any Julia object (e.g. Distributions.LKJCholesky produces draws of Cholesky objects). One needs to be able to linearize to compute certain diagnostics and plots, but discarding the structure of such objects is not ideal.

@yebai and I talked briefly about this today. We think there might be room for MCMCChains to solve both these issues, in which case we could use it here, and I'll open an issue on the MCMCChains repo to discuss further.

InferenceData would still go a bit further in tying together storage of objects (e.g. posterior and prior draws) needed at different stages of the workflow.

> Maybe just starting with https://github.com/arviz-devs/ArviZ.jl/issues/128 would be enough for now?

Yeah, I'm thinking of working on #128 sooner rather than later, using DimensionalData as the base. As long as we only interact with the storage object through our own (maybe internal) interface, then we can always switch as needed, and this way Julia users get more Julia-friendly storage immediately.

ParadaCarleton commented 2 years ago

> > What would such an interface entail? I ask because, apart from MCMCChains, I don't think there's any other package that you'd need an interface to use, and it's already easy enough to convert from MCMCChains to ArviZ.jl's InferenceData. I'm actually more interested in whether there's a real need to keep MCMCChains separate from ArviZ.jl, or whether in theory you could eventually merge them, the same way several Python PPLs switched from their own representations to using ArviZ's.
>
> Storage is only part of the picture. There are diagnostics like SBC that require access to some model object, the ability to condition it, and the ability to sample from it with an inference method (see TuringLang/AbstractPPL.jl#27 for a description of the minimal functionality). Doing so in a PPL-agnostic, user-friendly way requires an interface. So in short, no, MCMCChains does not do everything we need. AbstractPPL comes closer.
>
> But even for storage, MCMCChains doesn't give us everything we need (yet). The main issue is that Chains linearizes everything. So if one of your parameters is a 3D array, you as the user can't just get that 3D array (with 2 other dimensions for iteration and chain) from Chains without some boilerplate (see TuringLang/MCMCChains.jl#305). ArviZ on the other hand prioritizes parameters that have arbitrary numbers of dimensions, as well as named dims/coords; the latter are not only useful for analysis of large models but also for making more interpretable plots.
>
> But even ArviZ falls short on storage, because it's still biased toward arrays. In principle, however, one can have parameters be any Julia object (e.g. Distributions.LKJCholesky produces draws of Cholesky objects). One needs to be able to linearize to compute certain diagnostics and plots, but discarding the structure of such objects is not ideal.
>
> @yebai and I talked briefly about this today. We think there might be room for MCMCChains to solve both these issues, in which case we could use it here, and I'll open an issue on the MCMCChains repo to discuss further.
>
> InferenceData would still go a bit further in tying together storage of objects (e.g. posterior and prior draws) needed at different stages of the workflow.
>
> > Maybe just starting with #128 would be enough for now?
>
> Yeah, I'm thinking of working on #128 sooner rather than later, using DimensionalData as the base. As long as we only interact with the storage object through our own (maybe internal) interface, then we can always switch as needed, and this way Julia users get more Julia-friendly storage immediately.

Right, OK, I think I understand better what you mean: you're referring more to AbstractPPL-like interfaces for dealing with models and such, rather than replacing InferenceData with an interface that would let you deal with any set of objects, like MCMCChains or Soss's TupleVectors?

OriolAbril commented 2 years ago

> But even ArviZ falls short on storage, because it's still biased toward arrays. In principle, however, one can have parameters be any Julia object (e.g. Distributions.LKJCholesky produces draws of Cholesky objects). One needs to be able to linearize to compute certain diagnostics and plots, but discarding the structure of such objects is not ideal.

Is that something that could potentially be solved with some kind of metadata as in https://github.com/arviz-devs/arviz/issues/1975?

One of the goals of InferenceData is that it is easy to save to a file and load again (potentially in other languages), and afaik netcdf and zarr only support arrays. In fact, in my opinion the loaded object doesn't even need to be an "InferenceData", nor do all groups need to be loaded as part of the same object. Therefore, if we have a way to read/write those non-array objects and do any conversion needed in the from/to_netcdf/zarr methods, I don't see this as something different from or incompatible with InferenceData.

ParadaCarleton commented 2 years ago

> But even ArviZ falls short on storage, because it's still biased toward arrays. In principle, however, one can have parameters be any Julia object (e.g. Distributions.LKJCholesky produces draws of Cholesky objects). One needs to be able to linearize to compute certain diagnostics and plots, but discarding the structure of such objects is not ideal.

@sethaxen what about StructArrays? That sounds pretty similar.

sethaxen commented 2 years ago

> Is that something that could potentially be solved with some kind of metadata as in arviz-devs/arviz#1975?

I think this would be solving a slightly different problem. This would allow conversion between a convenient linearized representation and the "natural" (i.e. user-defined) representation of a draw of a parameter. That does not necessarily contain semantic information about what constraints there are on the parameter or what its type is.

> One of the goals of InferenceData is that it is easy to save to a file and load again (potentially in other languages), and afaik netcdf and zarr only support arrays. In fact, in my opinion the loaded object doesn't even need to be an "InferenceData", nor do all groups need to be loaded as part of the same object. Therefore, if we have a way to read/write those non-array objects and do any conversion needed in the from/to_netcdf/zarr methods, I don't see this as something different from or incompatible with InferenceData.

Hm, good point; we also need to think about serialization. The tricky thing with supporting such conversions to a natural representation is that this would be inherently Julia-specific (because the storage would be a Julia object), so while we could serialize the linearized representation for loading from other languages, I think this would lose the ability to convert back to a natural one.
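To illustrate with one concrete case, here is a sketch of a possible linearize/restructure pair for a Cholesky draw, assuming we standardize on storing the lower factor. This is a Julia-side round-trip only; a file loaded in another language would see just the plain matrix.

```julia
using LinearAlgebra

# A Cholesky "draw" of a 2×2 covariance matrix.
F = cholesky(Symmetric([2.0 0.5; 0.5 1.0]))

# Linearize: a plain array that netcdf/zarr could store.
flat = Matrix(F.L)

# Restructure on load: rebuild the natural object from the stored factor.
F_restored = Cholesky(LowerTriangular(flat))
```

The per-type restructure step (`Cholesky(LowerTriangular(...))` here) is exactly the Julia-specific knowledge that would be lost when loading the file elsewhere.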

> @sethaxen what about StructArrays? That sounds pretty similar.

So something like StructArrays (or ComponentArrays) is one way to go. StructArrays doesn't do recursive conversion, though, so you'd need to opt in on a per-type basis, which may make using it generically hard. E.g.:

julia> using StructArrays, LinearAlgebra, Distributions

julia> dx = Normal();

julia> dy = MvNormal([0, 0]);

julia> dF = LKJCholesky(5, 2.0);

julia> draws = StructArray([(x=rand(dx), y=rand(dy), F=rand(dF)) for chain in 1:2, draw in 1:10]);

julia> StructArrays.components(draws).F |> typeof
Matrix{Cholesky{Float64, Matrix{Float64}}} (alias for Array{Cholesky{Float64, Array{Float64, 2}}, 2})

julia> draws = StructArray([(x=rand(dx), y=rand(dy), F=rand(dF)) for chain in 1:2, draw in 1:10]; unwrap=T->T <: Cholesky);

julia> StructArrays.components(draws).F |> typeof
StructArray{Cholesky{Float64, Matrix{Float64}}, 2, NamedTuple{(:factors, :uplo, :info), Tuple{Matrix{Matrix{Float64}}, Matrix{Char}, Matrix{Int64}}}, Int64}

julia> StructArray([(x=rand(dx), y=rand(dy), F=rand(dF)) for chain in 1:2, draw in 1:10]; unwrap=T->true);
ERROR: BoundsError: attempt to access 0-element StructArray(StructArray(), StructArray(), StructArray(StructArray(), StructArray(), StructArray())) with eltype NamedTuple{(:x, :y, :F), Tuple{Float64, Vector{Float64}, Cholesky{Float64, Matrix{Float64}}}} with indices 1:0 at index [1]
Stacktrace:

It also doesn't play super well with array types with named dims/coords. e.g.

julia> using DimensionalData

julia> chains = 1:2
1:2

julia> iters = 1:10
1:10

julia> draws = StructArray(DimArray([(x=rand(dx), y=rand(dy), F=rand(dF)) for chain in chains, iter in iters], (chain=chains, iter=iters)); unwrap=T->T <: Cholesky);

julia> draws[chain=1, iter=1]
ERROR: MethodError: no method matching getindex(::StructArray{NamedTuple{(:x, :y, :F), Tuple{Float64, Vector{Float64}, Cholesky{Float64, Matrix{Float64}}}}, 2, NamedTuple{(:x, :y, :F), Tuple{DimArray{Float64, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Float64}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, DimArray{Vector{Float64}, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Vector{Float64}}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, StructArray{Cholesky{Float64, Matrix{Float64}}, 2, NamedTuple{(:factors, :uplo, :info), Tuple{DimArray{Matrix{Float64}, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, 
DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Matrix{Float64}}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, DimArray{Char, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Char}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, DimArray{Int64, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Int64}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Int64}}}, Int64}; chain=1, iter=1)
Closest candidates are:
  getindex(::StructArray{T, <:Any, <:Any, Int64}, ::Int64) where T at ~/.julia/packages/StructArrays/0C03x/src/structarray.jl:343 got unsupported keyword arguments "chain", "iter"
  getindex(::AbstractArray, ::Any...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/abstractarray.jl:1215 got unsupported keyword arguments "chain", "iter"
Stacktrace:
 [1] top-level scope
   @ REPL[69]:1

julia> draws.F
2×10 StructArray(::DimArray{Matrix{Float64}, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Matrix{Float64}}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, ::DimArray{Char, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Char}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}, ::DimArray{Int64, 2, Tuple{Dim{:chain, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}, Dim{:iter, DimensionalData.Dimensions.LookupArrays.Sampled{Int64, UnitRange{Int64}, DimensionalData.Dimensions.LookupArrays.ForwardOrdered, 
DimensionalData.Dimensions.LookupArrays.Regular{Int64}, DimensionalData.Dimensions.LookupArrays.Points, DimensionalData.Dimensions.LookupArrays.NoMetadata}}}, Tuple{}, Matrix{Int64}, DimensionalData.NoName, DimensionalData.Dimensions.LookupArrays.NoMetadata}) with eltype Cholesky{Float64, Matrix{Float64}}:
 Cholesky{Float64, Matrix{Float64}}([1.0 3.5e-323 … 2.59608e-314 6.4e-323; 0.292908 0.956141 … 6.0e-323 7.0e-323; … ; -0.0770292 0.663924 … 0.737009 2.59608e-314; -0.453924 -0.40182 … -0.647584 0.445791], 'L', 0)      …  Cholesky{Float64, Matrix{Float64}}([1.0 2.2338e-308 … 1.33143e-315 0.0; -0.279271 0.960212 … 2.58656e-231 0.0; … ; 0.235842 -0.00327757 … 0.957821 0.0; 0.0696024 0.288193 … 0.0352111 0.800725], 'L', 0)
 Cholesky{Float64, Matrix{Float64}}([1.0 2.32327e-314 … 1.3e-322 2.32327e-314; -0.246615 0.969114 … 3.0e-323 1.93e-322; … ; -0.0767973 0.495859 … 0.833087 2.32326e-314; 0.137547 0.407415 … 0.550548 0.606197], 'L', 0)     Cholesky{Float64, Matrix{Float64}}([1.0 2.18081e-314 … 2.18082e-314 2.18082e-314; 0.437534 0.899202 … 2.18082e-314 2.18082e-314; … ; 0.167467 0.127887 … 0.97316 2.18082e-314; 0.525501 -0.65536 … 0.404307 0.250537], 'L', 0)

julia> draws.x   # non-recursive types keep the names though
2×10 DimArray{Float64,2} with dimensions: 
  Dim{:chain} Sampled 1:2 ForwardOrdered Regular Points,
  Dim{:iter} Sampled 1:10 ForwardOrdered Regular Points
  0.804066  -0.73135    0.23626   -0.607354  2.00062  -1.26281   -0.341449  0.220834  -0.217121   0.266102
 -0.778128  -0.415387  -0.830491   0.148016  1.21364  -0.756179   2.11044   0.200427  -1.8515    -1.24878

julia> draws.y  # oh, but this is an array-of-arrays, so we couldn't specify its named dims/coords
2×10 DimArray{Vector{Float64},2} with dimensions: 
  Dim{:chain} Sampled 1:2 ForwardOrdered Regular Points,
  Dim{:iter} Sampled 1:10 ForwardOrdered Regular Points
 [-0.0, 0.0]   [-0.0, -0.0]  [0.0, 0.0]   [0.0, -0.0]  [-0.0, 0.0]  [-0.0, -0.0]  [0.0, -0.0]  [0.0, 0.0]    [0.0, -0.0]  [-0.0, 0.0]
 [-0.0, -0.0]  [-0.0, 0.0]   [0.0, -0.0]  [0.0, 0.0]   [-0.0, 0.0]  [0.0, 0.0]    [-0.0, 0.0]  [-0.0, -0.0]  [0.0, 0.0]   [0.0, 0.0]

This will all get even messier when trying to implement downstream plotting and diagnostics packages, because StructArrays can only be added to each other if all constituent types support addition, which in this case they don't. So one would need to special-case all diagnostics code for StructArrays.

Thinking out loud here: another thing we could do is automatically flatten all arrays of numeric arrays to numeric arrays (as we currently do) and (try to) wrap everything else recursively in StructArrays. Then, when we have named dims/coords, we could apply those on the outside of the StructArray. All of the parameters would then be wrapped in something like DimensionalData.DimStack, and we would only need to special-case whenever we encounter a DimArray whose eltype is not a Real subtype. For a first pass, we could even exclude the other types from diagnostic computation/plotting, i.e. support storing them but not yet performing analyses on them.
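The "flatten arrays of numeric arrays" step above can be sketched with only Base (the variable names are illustrative; a real implementation would also carry the dims/coords along):

```julia
# A 2×10 array of length-3 vector draws, like the `y` example above.
draws_of_vectors = [rand(3) for chain in 1:2, iter in 1:10]

# Flatten to one numeric array with an extra trailing dimension:
# hcat the draws column-major, then reshape/permute to chain × iter × param.
stacked = reshape(reduce(hcat, vec(draws_of_vectors)), 3, 2, 10)
flat = permutedims(stacked, (2, 3, 1))   # chain × iter × param index
```

With named dims applied on the outside, `flat` could become a DimArray with `chain`, `iter`, and a coord for the parameter's own index.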

OriolAbril commented 2 years ago

What are the main conversion challenges you expect between "natural" and linearized representations?

Not sure if I'm being too naive, but it would be nice for the goal to be a file format/specification that allows anyone to publish a netcdf/zarr file along with their paper, so that we can reproduce the analysis of their results without sampling.

ParadaCarleton commented 2 years ago

> So something like StructArrays (or ComponentArrays) is one way to go. StructArrays doesn't do recursive conversion, though, so you'd need to opt in on a per-type basis, which may make using it generically hard.

Do we want to do recursive conversion? It feels to me that having arrays bundled like this is an advantage rather than a disadvantage, since I often want to analyze a whole bunch of parameters stored in an array together.

> This will all get even messier when trying to implement downstream plotting and diagnostics packages, because StructArrays can only be added to each other if all constituent types support addition, which in this case they don't. So one would need to special-case all diagnostics code for StructArrays.

Do we actually need to add StructArrays at any point? In theory, all these analyses should be done componentwise, and each component should have a well-defined addition/subtraction operator. And if necessary we can define something like an "AlmostStructArray" which behaves the same way but defines addition componentwise.
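A minimal sketch of the componentwise idea, using only Base: rather than defining arithmetic on the container itself, map a statistic over a NamedTuple of per-parameter draw arrays (a stand-in here for the components of a StructArray).

```julia
# Two parameters with chain × iter (× extra) draw arrays.
components = (x = randn(2, 10), y = randn(2, 10, 3))

# Apply a statistic per component; no addition on the container is needed.
component_means = map(A -> sum(A) / length(A), components)
```

Any diagnostic that reduces over the chain/iter dimensions can be lifted this way, so the container only has to support `map` over its components.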

> Then, when we have named dims/coords, we could apply those on the outside of the StructArray.

That sounds good to me.

> For a first pass, we could even exclude the other types from diagnostic computation/plotting, i.e. support storing them but not yet performing analyses on them.

If they cause problems, we can work with the reals for now and add support for other types later.