arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Ensure dimension order can be anything #1693

Status: open. sethaxen opened this issue 3 years ago

sethaxen commented 3 years ago

Currently, all examples I have seen order the dimensions of samples as (chain, draw, shape...), where shape is the shape of a single draw of the variable. This particular ordering shouldn't be necessary, and allowing alternative orderings might enable samples from PPLs to be used without copying. For example, Turing stores variables as (draw, shape..., chain).
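The "without copying" point can be made concrete with NumPy alone: reordering axes returns a view rather than a copy, so a Turing-style array laid out as (draw, shape..., chain) can be presented in (chain, draw, shape...) order without duplicating the samples. The sketch below uses made-up array sizes, and it only demonstrates the NumPy behaviour. It says nothing about whether ArviZ's converters copy internally.

```python
import numpy as np

# Hypothetical Turing-style samples: 100 draws of a length-8 vector from 4 chains,
# stored as (draw, shape..., chain).
rng = np.random.default_rng(0)
turing_samples = rng.normal(size=(100, 8, 4))

# Move the last axis (chain) to the front: (draw, shape..., chain) -> (chain, draw, shape...).
# np.moveaxis returns a view, so no sample data is copied.
canonical = np.moveaxis(turing_samples, -1, 0)

print(canonical.shape)                              # (4, 100, 8)
print(np.shares_memory(canonical, turing_samples))  # True
```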

This might already be supported. If so, it should be tested.

junpenglao commented 3 years ago

+1. TFP MCMC currently outputs variables as (draw, chain, shape...). I think an option to pass an indication or mapping of where the draw and chain dimensions are would be sufficient.
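One way to read the "mapping" suggestion is that the caller passes a dict saying which axis holds `draw` and which holds `chain`, and the converter moves those two axes to the front. The following is a minimal sketch under that assumption. The helper `to_chain_draw` and the `{"draw": 0, "chain": 1}` format are hypothetical and not part of ArviZ.

```python
import numpy as np

def to_chain_draw(samples, axis_map):
    """Hypothetical helper: return a view of `samples` ordered as (chain, draw, shape...).

    `axis_map` gives the axis index of each sampling dimension, e.g.
    {"draw": 0, "chain": 1}; negative indices are allowed.
    """
    source = [axis_map["chain"], axis_map["draw"]]
    # moveaxis keeps the remaining (shape) axes in their original relative order
    return np.moveaxis(samples, source, [0, 1])

# TFP-style output: 500 draws, 4 chains, a length-3 vector per draw
rng = np.random.default_rng(1)
tfp_samples = rng.normal(size=(500, 4, 3))
canonical = to_chain_draw(tfp_samples, {"draw": 0, "chain": 1})
print(canonical.shape)  # (4, 500, 3)
```

The same helper would cover the Turing layout with `{"draw": 0, "chain": -1}`.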

OriolAbril commented 3 years ago

There are two things to take into account for this to work; luckily, with xarray the second one is already taken care of.

The first thing is the conversion from PPL data to InferenceData. Each PPL uses its own fixed ordering, but so far none uses named dimensions, so we need to know the ordering for things to work. Currently, all converters end up calling `dict_to_dataset`, which assumes the inputs are dicts with variable names as keys and arrays as values. Those arrays currently need to be ordered as (chain, draw, *shape). This has the great advantage of allowing things like `az.from_dict({"a": np.random.normal(size=(4, 100, 8))}, dims={"a": ["team"]})`, which works because ArviZ knows that if `chain` and `draw` are not in `dims`, they need to be prepended.
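The current behaviour can be sketched as follows. The user-supplied `dims` only name the shape axes, so `chain` and `draw` are prepended, and any shape axis left unnamed gets an automatic name. The helper `full_dims` is hypothetical, and the `{var}_dim_{i}` pattern is an assumption chosen to resemble ArviZ's default names. This is an illustration, not ArviZ's actual code.

```python
import numpy as np

def full_dims(var_name, values, dims=None, sample_dims=("chain", "draw")):
    """Hypothetical sketch: dimension names for an array ordered (chain, draw, *shape)."""
    dims = list(dims or [])
    n_shape_axes = values.ndim - len(sample_dims)
    if len(dims) > n_shape_axes:
        raise ValueError(f"{var_name}: got {len(dims)} dims for {n_shape_axes} shape axes")
    # Give any remaining shape axes an automatic name based on their position
    dims += [f"{var_name}_dim_{i}" for i in range(len(dims), n_shape_axes)]
    # chain and draw are not in `dims`, so they are prepended
    return list(sample_dims) + dims

a = np.random.default_rng(2).normal(size=(4, 100, 8))
print(full_dims("a", a, ["team"]))  # ['chain', 'draw', 'team']
```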

I think it would also be great to allow something like

```python
az.from_dict({"a": np.random.normal(size=(100, 8, 4))}, dims={"a": ["draw", "team", "chain"]})
```

and have it work. That will need some changes in the base converters, but it should be possible. Once this is done, we can also look into updating all the converters that currently use swapdims, as it may (or may not) be easier to handle things this way.
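Supporting this mostly comes down to computing an axis permutation before the dataset is built. If `chain` and `draw` appear in `dims`, their axes are moved to the front. Otherwise, today's behaviour applies and the names are prepended. Here is a minimal NumPy-only sketch, where `reorder_to_canonical` is a hypothetical helper name:

```python
import numpy as np

def reorder_to_canonical(values, dims, sample_dims=("chain", "draw")):
    """Hypothetical sketch: return (values, dims) with the sample dims first.

    If all of `sample_dims` appear in `dims`, those axes are moved to the front
    (as a view). If none appear, the array is assumed to already be ordered
    (chain, draw, *shape) and the names are simply prepended.
    """
    dims = list(dims)
    present = [d for d in sample_dims if d in dims]
    if not present:
        return values, list(sample_dims) + dims
    if len(present) != len(sample_dims):
        raise ValueError("dims must name either all or none of the sample dims")
    if len(dims) != values.ndim:
        raise ValueError("dims must name every axis when sample dims are included")
    source = [dims.index(d) for d in sample_dims]
    reordered = np.moveaxis(values, source, list(range(len(sample_dims))))
    new_dims = list(sample_dims) + [d for d in dims if d not in sample_dims]
    return reordered, new_dims

values, dims = reorder_to_canonical(
    np.random.default_rng(3).normal(size=(100, 8, 4)), ["draw", "team", "chain"]
)
print(values.shape, dims)  # (4, 100, 8) ['chain', 'draw', 'team']
```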

The second thing is that once the data is in InferenceData, ArviZ must always use named dims, so it no longer cares about the order of the dimensions, only about their names (like xarray does). I am 90% sure that all ArviZ code does this already, but we need to check. Here are some examples showing how this already works:

```python
import arviz as az

idata = az.load_arviz_data("rugby")
ds = idata.posterior[["atts", "defs"]]
# Put the dimensions in a non-default order; ArviZ should rely only on their names
ds = ds.transpose("draw", "team", "chain")
az.summary(ds)
```

```
          mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
atts[0]  0.171  0.042   0.095    0.251  ...    0.001    2574.0    1409.0    1.0
atts[1] -0.083  0.044  -0.164   -0.001  ...    0.001    2330.0    1680.0    1.0
atts[2]  0.108  0.041   0.036    0.185  ...    0.001    2652.0    1250.0    1.0
atts[3] -0.116  0.048  -0.203   -0.026  ...    0.001    2156.0    1483.0    1.0
atts[4] -0.337  0.057  -0.449   -0.238  ...    0.001    2333.0    1350.0    1.0
atts[5]  0.257  0.042   0.179    0.338  ...    0.001    2418.0    1401.0    1.0
defs[0] -0.129  0.051  -0.231   -0.038  ...    0.001    2227.0    1449.0    1.0
defs[1] -0.041  0.047  -0.127    0.046  ...    0.001    2379.0    1779.0    1.0
defs[2] -0.388  0.053  -0.490   -0.292  ...    0.001    2407.0    1647.0    1.0
defs[3]  0.173  0.043   0.086    0.249  ...    0.001    2243.0    1736.0    1.0
defs[4]  0.583  0.036   0.517    0.651  ...    0.001    2151.0    1672.0    1.0
defs[5] -0.196  0.051  -0.296   -0.103  ...    0.001    2170.0    1518.0    1.0

[12 rows x 9 columns]
```

```python
az.plot_forest(ds)
```

[Image: forest plot of atts and defs produced by az.plot_forest(ds)]

junpenglao commented 3 years ago

Expanding on the first case, how about supporting an API like:

```python
az.from_dict(
    {"a": np.random.normal(size=(100, 8, 4)),
     "b": np.random.normal(size=(100, 4)),
    },
    draw_axis=0,
    chain_axis=-1,
)
az.from_dict(
    {"a": np.random.normal(size=(100, 4, 8)),
     "b": np.random.normal(size=(100, 4)),
    },
    draw_axis=0,
    chain_axis=1,
)
```

with the defaults being draw_axis=1, chain_axis=0, which corresponds to the current (chain, draw, *shape) layout.
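The proposed `draw_axis`/`chain_axis` keywords could be implemented as a preprocessing step that normalizes every variable to (chain, draw, *shape) before the existing conversion code runs. Here is a minimal sketch with a hypothetical `normalize_samples` helper. Its defaults follow the proposal (`draw_axis=1`, `chain_axis=0`), and it also checks that all variables agree on the number of chains and draws.

```python
import numpy as np

def normalize_samples(data, draw_axis=1, chain_axis=0):
    """Hypothetical sketch: reorder every array in `data` to (chain, draw, *shape).

    Negative axes are resolved per variable by np.moveaxis, which also raises
    ValueError if chain_axis and draw_axis point at the same axis.
    """
    out = {}
    sizes = None
    for name, values in data.items():
        moved = np.moveaxis(np.asarray(values), [chain_axis, draw_axis], [0, 1])
        # All variables must agree on the number of chains and draws
        if sizes is None:
            sizes = moved.shape[:2]
        elif moved.shape[:2] != sizes:
            raise ValueError(
                f"{name}: expected (chain, draw) sizes {sizes}, got {moved.shape[:2]}"
            )
        out[name] = moved
    return out

rng = np.random.default_rng(4)
samples = normalize_samples(
    {"a": rng.normal(size=(100, 8, 4)), "b": rng.normal(size=(100, 4))},
    draw_axis=0,
    chain_axis=-1,
)
print({k: v.shape for k, v in samples.items()})  # {'a': (4, 100, 8), 'b': (4, 100)}
```

Resolving negative axes per variable is what lets `chain_axis=-1` mean "last axis" for both the 3-D `a` and the 2-D `b`.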

junpenglao commented 3 years ago

Oh, it also makes supporting single chains and MAP estimates easier:

```python
# Single chain
az.from_dict(
    {"a": np.random.normal(size=(100, 8)),
     "b": np.random.normal(size=100),
    },
    draw_axis=0,
    chain_axis=None,
)
# MAP
az.from_dict(
    {"a": np.random.normal(size=8),
     "b": np.random.normal(size=()),
    },
    draw_axis=None,
    chain_axis=None,
)
# Multiple runs of MAP (e.g., MAP solved from different starting positions)
az.from_dict(
    {"a": np.random.normal(size=(4, 8)),
     "b": np.random.normal(size=4),
    },
    draw_axis=None,
    chain_axis=0,
)
```
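Treating `None` as "this axis is absent" fits naturally into such a normalization step. A missing axis can be added as a length-1 axis with `np.expand_dims`, so single-chain runs and MAP results end up in the same (chain, draw, *shape) layout as full MCMC output. Here is a hedged sketch, again with a hypothetical helper name (`normalize_optional_axes`):

```python
import numpy as np

def normalize_optional_axes(values, draw_axis=1, chain_axis=0):
    """Hypothetical sketch: return `values` as (chain, draw, *shape).

    An axis given as None is treated as missing and inserted with length 1.
    """
    values = np.asarray(values)
    # Move the axes that do exist to the front, chain before draw
    present = [ax for ax in (chain_axis, draw_axis) if ax is not None]
    if present:
        values = np.moveaxis(values, present, list(range(len(present))))
    # Insert length-1 axes for the missing ones at their canonical positions
    if chain_axis is None:
        values = np.expand_dims(values, 0)
    if draw_axis is None:
        values = np.expand_dims(values, 1)
    return values

single = normalize_optional_axes(np.zeros((100, 8)), draw_axis=0, chain_axis=None)
map_est = normalize_optional_axes(np.zeros(8), draw_axis=None, chain_axis=None)
multi_map = normalize_optional_axes(np.zeros((4, 8)), draw_axis=None, chain_axis=0)
print(single.shape, map_est.shape, multi_map.shape)  # (1, 100, 8) (1, 1, 8) (4, 1, 8)
```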