TuringLang / MCMCDiagnosticTools.jl

https://turinglang.org/MCMCDiagnosticTools.jl/dev

Supporting more shapes #78

Closed sethaxen closed 1 year ago

sethaxen commented 1 year ago

Currently the more modern methods support the following shapes:

Looking at these, I think there's a clear continuum of shapes we can interpret and support: (ndraws[, nchains[, nparams...]]). bfmi can obviously only support the first two dimensions. The other methods, however, could also support the vector case and multiple trailing parameter dimensions, as might result from stacking chains of draws of matrix-valued random variables.
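Below is a minimal sketch in plain Julia, with no packages, showing one way the (ndraws[, nchains[, nparams...]]) continuum could be read. The names `as_draws_chains_params`, `apply_per_param` and `toy_stat` are hypothetical and are not part of MCMCDiagnosticTools. `toy_stat` stands in for a diagnostic like `ess` that only accepts 3D input. Returning a scalar when there are no trailing parameter dimensions is a design assumption of this sketch, not necessarily what #79 does.

```julia
# Hypothetical helper (not MCMCDiagnosticTools API): view an array shaped
# (ndraws[, nchains[, nparams...]]) as a 3D (ndraws, nchains, nparams) array.
function as_draws_chains_params(x::AbstractArray)
    ndims(x) >= 1 || throw(ArgumentError("expected at least a draws dimension"))
    # size(x, 2) returns 1 when x is a vector, so a vector becomes (ndraws, 1, 1);
    # the trailing `:` collapses all parameter dimensions into one.
    return reshape(x, size(x, 1), size(x, 2), :)
end

# Toy stand-in for a 3D-only diagnostic: mean over draws and chains,
# returning one value per (flattened) parameter.
function toy_stat(x3::AbstractArray{<:Real,3})
    return vec(sum(x3; dims=(1, 2))) ./ (size(x3, 1) * size(x3, 2))
end

# Hypothetical wrapper: apply a 3D-only function `f`, then restore the
# trailing parameter dimensions of the original input.
function apply_per_param(f, x::AbstractArray)
    result = f(as_draws_chains_params(x))
    # Assumption of this sketch: no parameter dimensions -> return a scalar.
    ndims(x) <= 2 && return only(result)
    return reshape(result, size(x)[3:end])
end

x4 = randn(1000, 4, 3, 4)              # (draw, chain, param1, param2)
size(apply_per_param(toy_stat, x4))    # parameter shape (3, 4) is restored
apply_per_param(toy_stat, randn(1000)) # a single Float64
```

Because `reshape` is column-major, flattening and restoring the parameter dimensions keeps every result aligned with its parameter. For a DimArray, this same `reshape` is what drops the dimension names. Supporting these shapes natively avoids that step.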

Why do this? A user can always reshape to a 3D array, but that is not ideal for arrays with named dimensions/indices, because reshape drops all dimension names and indices. For example:

julia> using DimensionalData, MCMCDiagnosticTools

julia> da1 = DimArray(randn(1000, 1, 1), (:draw, :chain, :param));

julia> ess(da1)  # ideal case, dimensions preserved
1-element DimArray{Float64,1} with dimensions: Dim{:param}
 1  960.129

julia> da2 = DimArray(randn(1000, 1), (:draw, :chain));

julia> ess(reshape(da2, size(da2)..., 1))  # named dimensions lost
1-element Vector{Float64}:
 1036.1139385722638

julia> da3 = DimArray(randn(1000), (:draw,));

julia> ess(reshape(da3, size(da3)..., 1, 1))  # named dimensions lost
1-element Vector{Float64}:
 929.3129270202691

julia> da4 = DimArray(randn(1000, 4, 3, 4), (:draw, :chain, :param1, :param2));

julia> ess(reshape(da4, size(da4,1), size(da4,2), :))  # named dimensions lost
12-element Vector{Float64}:
 3950.8133148219445
 3963.51341805499
 3887.7997316174083
 4058.9638959410036
 3685.782636881455
 3633.6820465330984
 3868.9859706770994
 3792.095865124389
 3921.8966121239305
 4111.7648208639375
 3561.7107571982556
 3508.2107768022665

ArviZ uses DimensionalData.DimArrays to store draws, so every call to one of these methods needs quite a bit of boilerplate. We have to reshape the input for MCMCDiagnosticTools, reshape the result back, and then reattach the dimensions. The proposed generalization remains unambiguous and lets the functions be used more ergonomically in such cases.

sethaxen commented 1 year ago

@devmotion what do you think of this proposal?

sethaxen commented 1 year ago

Fixed by #79