sethaxen opened 3 years ago
+1. TFP MCMC currently outputs variables as `(draw, chain, *shape)`. I think an option to pass an indication or mapping of where the `draw` and `chain` dimensions are would be sufficient.
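Until such an option exists, a TFP-style array can be reordered by hand before conversion. This is a minimal NumPy sketch, assuming the `(draw, chain, *shape)` layout described above; the array is random stand-in data, not real TFP output:

```python
import numpy as np

# Stand-in for TFP MCMC output: 100 draws, 4 chains, a length-3 parameter,
# laid out as (draw, chain, *shape).
tfp_samples = np.random.default_rng(0).normal(size=(100, 4, 3))

# ArviZ's converters currently expect (chain, draw, *shape), so swap the
# first two axes. For an ndarray, np.swapaxes returns a view (no data copy).
arviz_ready = np.swapaxes(tfp_samples, 0, 1)

print(arviz_ready.shape)  # (4, 100, 3)
```

The swapped array can then be passed to `az.from_dict` as usual.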
There are two things to take into account for this to work; luckily, thanks to xarray, the second is already taken care of.
The first is the conversion from PPL data to InferenceData. Each PPL uses a given ordering, but so far none uses named dimensions, so we need to know the ordering for things to work. Currently all converters end up calling `dict_to_dataset`, which assumes the input is a dict with variable names as keys and arrays as values. Those arrays currently need to have shape `(chain, draw, *shape)`, which has the great advantage of allowing things like `az.from_dict({"a": np.random.normal(size=(4, 100, 8))}, dims={"a": ["team"]})`. This works because ArviZ knows that if `chain` and `draw` are not in `dims`, they need to be prepended.
I think it would also be great to allow something like `az.from_dict({"a": np.random.normal(size=(100, 8, 4))}, dims={"a": ["draw", "team", "chain"]})` and have it work. This will need some work in the base converters, but it should be possible. Once this is done, we can also look into updating all the converters that currently use `swapdims`, as it may (or may not) be easier to handle things this way.
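One way a base converter could support this is to look up `chain` and `draw` in `dims` and move those axes to the front before building the dataset. The sketch below is only an illustration under that assumption: `reorder_by_dims` is a hypothetical helper, not part of ArviZ, and it keeps the current behaviour when neither sample dimension is named:

```python
import numpy as np


def reorder_by_dims(array, dims):
    """Hypothetical helper: return (array, dims) ordered as (chain, draw, *rest).

    If neither "chain" nor "draw" appears in dims, fall back to the current
    convention: the array already starts with (chain, draw) and dims only
    names the remaining axes.
    """
    dims = list(dims)
    has_chain, has_draw = "chain" in dims, "draw" in dims
    if not has_chain and not has_draw:
        # Current behaviour: sample dimensions are implicit and prepended.
        return array, ["chain", "draw", *dims]
    if not (has_chain and has_draw):
        raise ValueError("dims must name both 'chain' and 'draw', or neither")
    if len(dims) != array.ndim:
        raise ValueError(f"got {len(dims)} dims for an array with ndim={array.ndim}")
    chain_pos, draw_pos = dims.index("chain"), dims.index("draw")
    # Move the named sample axes to the front; other axes keep their order.
    reordered = np.moveaxis(array, [chain_pos, draw_pos], [0, 1])
    rest = [d for d in dims if d not in ("chain", "draw")]
    return reordered, ["chain", "draw", *rest]


rng = np.random.default_rng(0)
arr, names = reorder_by_dims(rng.normal(size=(100, 8, 4)), ["draw", "team", "chain"])
print(arr.shape, names)  # (4, 100, 8) ['chain', 'draw', 'team']
```

Requiring both names (or neither) avoids guessing where an unnamed sample axis would be.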
The second is that once the data is in InferenceData, ArviZ needs to always use named dims, so it doesn't actually care about the order of the dimensions anymore, only about their names (like xarray does). I am 90% sure that all ArviZ code does that already, but we need to check. Here are some examples showing how this already works:
```python
import arviz as az

idata = az.load_arviz_data("rugby")
ds = idata.posterior[["atts", "defs"]]
ds = ds.transpose("draw", "team", "chain")
az.summary(ds)
```

```
           mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
atts[0]   0.171  0.042   0.095    0.251  ...    0.001    2574.0    1409.0    1.0
atts[1]  -0.083  0.044  -0.164   -0.001  ...    0.001    2330.0    1680.0    1.0
atts[2]   0.108  0.041   0.036    0.185  ...    0.001    2652.0    1250.0    1.0
atts[3]  -0.116  0.048  -0.203   -0.026  ...    0.001    2156.0    1483.0    1.0
atts[4]  -0.337  0.057  -0.449   -0.238  ...    0.001    2333.0    1350.0    1.0
atts[5]   0.257  0.042   0.179    0.338  ...    0.001    2418.0    1401.0    1.0
defs[0]  -0.129  0.051  -0.231   -0.038  ...    0.001    2227.0    1449.0    1.0
defs[1]  -0.041  0.047  -0.127    0.046  ...    0.001    2379.0    1779.0    1.0
defs[2]  -0.388  0.053  -0.490   -0.292  ...    0.001    2407.0    1647.0    1.0
defs[3]   0.173  0.043   0.086    0.249  ...    0.001    2243.0    1736.0    1.0
defs[4]   0.583  0.036   0.517    0.651  ...    0.001    2151.0    1672.0    1.0
defs[5]  -0.196  0.051  -0.296   -0.103  ...    0.001    2170.0    1518.0    1.0

[12 rows x 9 columns]
```

```python
az.plot_forest(ds)
```
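The reason the transposed dataset still works is that xarray reductions are expressed in terms of dimension names rather than axis positions. A small sketch with plain xarray and random stand-in data (not the rugby posterior, and no ArviZ calls) illustrates the idea:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(4, 100, 6)),
    dims=("chain", "draw", "team"),
    name="atts",
)

# Same data with a different axis order, like ds.transpose("draw", "team", "chain").
da_t = da.transpose("draw", "team", "chain")

# Reducing over named dims gives the same per-team result for both layouts.
mean_original = da.mean(dim=("chain", "draw"))
mean_transposed = da_t.mean(dim=("chain", "draw"))
xr.testing.assert_allclose(mean_original, mean_transposed)
print(mean_transposed.dims)  # ('team',)
```

Any ArviZ code that indexes by position instead of by name would break this guarantee, which is what needs checking.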
Expanding on the first case, how about supporting an API like:
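```python
import arviz as az
import numpy as np

# Proposed API: draw_axis/chain_axis are the keyword arguments suggested here.
# "a" is laid out as (draw, *shape, chain), "b" as (draw, chain).
az.from_dict(
    {"a": np.random.normal(size=(100, 8, 4)),
     "b": np.random.normal(size=(100, 4))},
    draw_axis=0,
    chain_axis=-1,
)
# "a" is laid out as (draw, chain, *shape), "b" as (draw, chain).
az.from_dict(
    {"a": np.random.normal(size=(100, 4, 8)),
     "b": np.random.normal(size=(100, 4))},
    draw_axis=0,
    chain_axis=1,
)
```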
with the defaults being `draw_axis=1, chain_axis=0`, matching the current `(chain, draw, *shape)` convention.
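One detail of this proposal: a negative axis such as `chain_axis=-1` resolves to a different position for each variable, because `a` and `b` have different numbers of dimensions. A minimal sketch of how the proposed arguments could be applied per variable with NumPy; `move_sample_axes` is a hypothetical name used only for illustration:

```python
import numpy as np


def move_sample_axes(array, draw_axis=1, chain_axis=0):
    """Hypothetical: reorder `array` to (chain, draw, *shape).

    With the defaults (the current convention) the axis order is unchanged.
    np.moveaxis resolves negative axes against this array's own ndim and
    raises ValueError if both arguments point at the same axis.
    """
    return np.moveaxis(array, [chain_axis, draw_axis], [0, 1])


rng = np.random.default_rng(0)
posterior = {
    "a": rng.normal(size=(100, 8, 4)),
    "b": rng.normal(size=(100, 4)),
}
reordered = {
    name: move_sample_axes(values, draw_axis=0, chain_axis=-1)
    for name, values in posterior.items()
}
print({name: arr.shape for name, arr in reordered.items()})
# {'a': (4, 100, 8), 'b': (4, 100)}
```

The second call from the proposal (`draw_axis=0, chain_axis=1` on a `(100, 4, 8)` array) would likewise yield `(4, 100, 8)`.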
Oh, it also makes supporting single-chain results and MAP estimates easier:
```python
# Single chain
az.from_dict(
    {"a": np.random.normal(size=(100, 8)),
     "b": np.random.normal(size=(100,))},
    draw_axis=0,
    chain_axis=None,
)
# MAP
az.from_dict(
    {"a": np.random.normal(size=(8,)),
     "b": np.random.normal(size=())},
    draw_axis=None,
    chain_axis=None,
)
# Multiple runs of MAP (e.g., MAP solved from different starting positions)
az.from_dict(
    {"a": np.random.normal(size=(4, 8)),
     "b": np.random.normal(size=(4,))},
    draw_axis=None,
    chain_axis=0,
)
```
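A `None` axis could be handled by inserting a length-1 dimension, so every case above ends up in the same `(chain, draw, *shape)` layout. This is a sketch under that assumption; `with_sample_axes` is hypothetical and not an existing ArviZ function:

```python
import numpy as np


def with_sample_axes(array, draw_axis=1, chain_axis=0):
    """Hypothetical: return `array` as (chain, draw, *shape).

    A missing sample dimension (axis given as None) becomes a length-1 axis.
    """
    array = np.asarray(array)
    if chain_axis is None and draw_axis is None:
        # MAP: one "chain" holding one "draw".
        return array[np.newaxis, np.newaxis, ...]
    if chain_axis is None:
        # Single chain: move draws to the front, then add a chain axis.
        return np.expand_dims(np.moveaxis(array, draw_axis, 0), 0)
    if draw_axis is None:
        # Several MAP runs: each run is a chain with a single draw.
        return np.expand_dims(np.moveaxis(array, chain_axis, 0), 1)
    return np.moveaxis(array, [chain_axis, draw_axis], [0, 1])


rng = np.random.default_rng(0)
print(with_sample_axes(rng.normal(size=(100, 8)), draw_axis=0, chain_axis=None).shape)  # (1, 100, 8)
print(with_sample_axes(rng.normal(size=8), draw_axis=None, chain_axis=None).shape)  # (1, 1, 8)
print(with_sample_axes(rng.normal(size=(4, 8)), draw_axis=None, chain_axis=0).shape)  # (4, 1, 8)
```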
Currently, all examples I have seen order the dimensions of samples as `(chain, draw, shape...)`, where `shape` is the shape of a single draw of the variable. This particular ordering shouldn't be necessary, and allowing alternative orderings might enable samples from PPLs to be used without copying. For example, Turing stores variables as `(draw, shape..., chain)`. This might already be supported; if so, it should be tested.
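At the NumPy level, reordering such a layout does not require a copy: `np.moveaxis` returns a view. The sketch below uses a random stand-in array with Turing's `(draw, shape..., chain)` layout, not actual Turing output. Note that the view is not C-contiguous, so whether later conversion steps copy it is exactly what a test would need to check:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for Turing-style storage (draw, shape..., chain):
# 100 draws, a length-3 parameter, 4 chains.
turing_like = rng.normal(size=(100, 3, 4))

# Reorder to (chain, draw, shape...) as a view on the same buffer.
view = np.moveaxis(turing_like, [-1, 0], [0, 1])
print(view.shape)                           # (4, 100, 3)
print(np.shares_memory(view, turing_like))  # True
print(view.flags["C_CONTIGUOUS"])           # False
```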