arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Use xarray throughout #97

Closed ColCarroll closed 6 years ago

ColCarroll commented 6 years ago

There have been proposals to use xarray as a common language for pymc3, pystan, and pymc4. This library might be a good place to start that by implementing utility functions for translating pymc3's Multitrace and pystan's OrderedDict into xarray objects, and then having all plotting functions work with xarrays.
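As a rough illustration of the kind of converter being proposed, here is a sketch (hypothetical function name, not an arviz API) that reshapes a plain dict of flattened draws — the layout both pymc3's MultiTrace and pystan's `fit.extract()` can yield — into an `xarray.Dataset` with `chain` and `draw` dimensions:

```python
import numpy as np
import xarray as xr

def dict_to_xarray(draws, n_chains):
    """Hypothetical converter: dict of (chain * draw, *shape) arrays -> Dataset."""
    data_vars = {}
    for name, values in draws.items():
        values = np.asarray(values)
        # Split the flat leading axis into (chain, draw), keep the rest as-is.
        shaped = values.reshape(n_chains, -1, *values.shape[1:])
        extra = [f"{name}_dim_{i}" for i in range(shaped.ndim - 2)]
        data_vars[name] = (["chain", "draw"] + extra, shaped)
    return xr.Dataset(data_vars)

posterior = dict_to_xarray(
    {"mu": np.random.randn(2000), "theta": np.random.randn(2000, 8)},
    n_chains=4,
)
```

Plotting functions could then work purely with named dimensions instead of library-specific trace objects.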

ahartikainen commented 6 years ago

The PyStan fit object is a StanFit4Model Cython class. The ordered dict comes from fit.extract, but there are other ways to interact with it.

Edit: typos.

twiecki commented 6 years ago

@ahartikainen The idea would be for PyStan to write a converter from StanFit4Model (or the ordered dict) to an xarray that can then be passed to arviz.

junpenglao commented 6 years ago

Should the xarray object also include diagnostics such as rhat and effective sample size? The PyStan object has a Cython function that computes rhat and effective sample size quite efficiently. Last time, when I was updating the effective sample size implementation in pymc3, I found that the pure-Python implementation is much slower when the input array is large. @aseyboldt has a numba implementation which is much faster, but PyMC3 does not have numba as a dependency.

My proposal is that we make a faster implementation of rhat and effective sample size in arviz (the numba version), plus a general converter that changes a pymc3 trace or StanFit4Model into xarray (could be two independent converters). It would have the trace as the first level, and diagnostics and stats as the second level. The second level is optional but will be computed by default if not provided.
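For reference, the NumPy core of the (split-free) Gelman–Rubin R̂ that such a fast implementation would accelerate could look like this sketch; a numba version would essentially `@njit` it. The function name and exact formulation are illustrative, not arviz's actual implementation:

```python
import numpy as np

def rhat(samples):
    """Classic (split-free) potential scale reduction for one scalar parameter.

    samples: array of shape (n_chains, n_draws).
    """
    n = samples.shape[1]
    chain_means = samples.mean(axis=1)
    within = samples.var(axis=1, ddof=1).mean()   # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)         # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled variance estimate
    return float(np.sqrt(var_hat / within))
```

Well-mixed chains give values near 1; chains stuck at different locations inflate the between-chain term and push R̂ well above 1.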

Thoughts?

twiecki commented 6 years ago

Good idea. I think it makes sense to allow inclusion of these stats in the xarray and, if not present, compute them using a fast (numba) implementation in arviz.

ahartikainen commented 6 years ago

@twiecki yes, that is a good idea. I'm currently writing a function to transform the fit object into a suitable format for the current implementation (of the plotting code).

But if ArviZ is going to implement code to read a specific type of xarray object, we can wait for that before we release the arviz wrapper code.

@junpenglao It would be ideal to just give a fit object to arviz and let arviz do the magic internally.

Do you think numba is going to be a dependency or optional?

aseyboldt commented 6 years ago

I'm not sure computing rhat and other stats in arviz is a good idea. Some of that depends on the sampler (e.g. treedepth), and if we want to print warnings or a report of some kind after sampling, then we need to compute it there anyway. I think it would be good if those stats could be in the xarray as well, but there is the problem of name collisions. If we only have rhat and n_effective_samples or so in the array, that would be fine, but if we also want to put all sampler-specific stats in it, then that list would suddenly be much longer and would also change over time. An alternative would be to have two xarrays: one for the samples, and one for the stats.
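A minimal sketch of the two-xarray idea (variable names illustrative): keeping posterior draws and sampler statistics in separate Datasets means a model variable can share a name with a sampler stat without any collision:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(42)

# A model variable happens to be called `energy`...
posterior = xr.Dataset(
    {"energy": (("chain", "draw"), rng.standard_normal((4, 500)))}
)

# ...and so is the NUTS Hamiltonian energy stat; no clash, different Dataset.
sample_stats = xr.Dataset({
    "energy": (("chain", "draw"), rng.standard_normal((4, 500))),
    "diverging": (("chain", "draw"), np.zeros((4, 500), dtype=bool)),
})
```

Both Datasets still share the `chain` and `draw` dimensions, so they stay aligned for joint operations.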

ahartikainen commented 6 years ago

Yes, summary stats and samples basically live in different spaces.

Either two xarrays or some tricks with the indexing could work.

twiecki commented 6 years ago

I don't see a problem with making numba a dependency; it has matured a lot and is well packaged. Eventually there is a high chance scipy will use it too.

junpenglao commented 6 years ago

Personally, I prefer to make it a dependency - the alternative implementation is too slow. I am hoping the numba installation is not an issue anymore - last time I checked, there were some complications if you are not using conda.

Also, I came across this discussion on the mc-stan discourse: Proposal for consolidated output, and the related Stan wiki by @martinmodrak and @sakrejda. I think this is a great discussion to have; specifically, whether there could be a universal representation of different inference methods, with their related diagnostics.

My idea would be:

junpenglao commented 6 years ago

Maybe we should start a google doc and also invite others working on PPL to edit? For example the folks from tensorflow/probability, Pyro etc.

twiecki commented 6 years ago

Good idea. CC @fritzo @dustinvtran ArviZ is a package that separates out PyMC3's plotting and some analysis functionality to create many commonly used plots like a traceplot. With PyStan we're currently discussing potential standardized storage objects (xarray). Is there any interest from Edward/Edward2/Pyro to collaborate on this?

ahartikainen commented 6 years ago

cc for active (Py)Stan folks

@ariddell @seantalts @braaannigan

ariddell commented 6 years ago

I have no objections to using xarrays. Sounds like a reasonable idea.

braaannigan commented 6 years ago

No objections, probably handy to agree on a naming convention for parameters before implementing in the various code bases

fritzo commented 6 years ago

@twiecki Yes the Pyro team is interested in collaborating on standardized storage formats that can facilitate comparison and encourage inference algorithm research. cc @neerajprad @jpchen @rohitsingh who are working on HMC and Pyro-Stan compatibility.

eb8680 commented 6 years ago

cc @yebai @xukai92 you might be interested in this for Turing

sakrejda commented 6 years ago

Re: @junpenglao, we've explicitly punted on this question for the moment while we re-organize the intermediate layer to make it possible. I agree that it's important, and it has been a topic of ongoing discussion, so if there is a broader discussion please let us know. Since Stan is multi-interface it might be more complicated at the file-format level (we've been talking about something streaming-friendly in ProtoBuf), but at the level of deciding how outputs should be grouped and organized it would be fantastic to have compatibility with other projects.

shoyer commented 6 years ago

As an xarray developer and probabilistic programming enthusiast, I'd really love to see this happen. Please feel free to ping me if you come across any issues.

dustinvtran commented 6 years ago

@matthewdhoffman, @davmre, @jvdillon, @csuter, @derifatives, @axch, @srvasude

Edward2 and TFP's abstraction level doesn't really require named data structures except for dicts and namedtuples, which are more for collecting heterogeneous data. For PyMC*, Stan, and others, xarray instead of custom classes or pd.DataFrame sounds like a great idea.

springcoil commented 6 years ago

Wow awesome proposal guys!

aseyboldt commented 6 years ago

If any of the pymc folks want to try it out in some projects, there is some code for getting pymc traces into xarray here: https://discourse.pymc.io/t/use-xarray-for-traces/73 For some time now, I've been wrapping pymc traces in a fit object that serialises to a netcdf file. The format for that file could be something like this:

And we can put some meta info in the attr tags as well. If we had a serialization format for the model itself, we could also add a /pymc_model group and store it there. (@stan-folks /stan_model :-) )

This should avoid name collisions between stats and variables, but it would probably mean that we have to duplicate some dimension labels if we need them in more than one group (not sure how to get xarray to read dimensions if they are in different groups, but if that works we could just add /dims).

aseyboldt commented 6 years ago

On the topic of interoperability: I think it would be great if we could get this to a point where different tools use the same format. But I think we also need to be careful not to promise too much interoperability. I can't see a reason why /trace couldn't be the same for e.g. stan and pymc, but for /trace_stats this isn't as clear anymore. They store basically the same thing, but keeping that interoperable might hinder development. An attr that stores which program and version created those stats might be really helpful, and a visualisation lib could then just have specific code to read those if necessary.

ahartikainen commented 6 years ago

Hey, @aseyboldt @ColCarroll

I have basically put stuff into xarray from the PyStan fit object. This should work with PyStan 2.16 onwards (I updated our .extract method). It can still be updated for earlier versions if needed. I split the data between warmup and sampling.

def pars_to_xarray(fit, pars=None, infer_dtypes=True):
    # regex magic to infer ints from the model code automatically
    ...
    return data_set, data_set_warmup

def sampler_params_to_xarray(fit, params=None):
    ...
    return sampler_params_dataset, sampler_params_dataset_warmup

def inits_to_xarray(fit, pars=None, infer_dtypes=True):
    ...
    return inits_dataset

def summary_to_xarray(fit, pars=None):
    # transform summary data to a dataframe --> xarray
    ...
    return summary_dataset, c_summary_dataset

All the sampled parameters are in their "original" shape (same goes for the inits function)

shape = (draw, chain, *parameter_shape)
val, vec, mat -->
val, (draw, chain)
vec, (draw, chain, vec_axis1)
mat, (draw, chain, mat_axis1, mat_axis2)
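In xarray terms, that shape convention could be sketched like this (dimension names as listed above; the values are just illustrative dummy data):

```python
import numpy as np
import xarray as xr

draw, chain = 500, 4

# One Dataset variable per sampled parameter, keeping the original
# parameter shape after the leading (draw, chain) axes.
ds = xr.Dataset({
    "val": (("draw", "chain"), np.zeros((draw, chain))),
    "vec": (("draw", "chain", "vec_axis1"), np.zeros((draw, chain, 3))),
    "mat": (("draw", "chain", "mat_axis1", "mat_axis2"),
            np.zeros((draw, chain, 3, 2))),
})
```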

All the sampler parameters are

accept_stat__, (draw, chain) float64
stepsize__, (draw, chain) float64
treedepth__, (draw, chain) int64
n_leapfrog__, (draw, chain) int64
divergent__, (draw, chain) bool
energy__, (draw, chain) float64
lp__, (draw, chain) float64

The summary xarray Dataset is in flatname format:

val, (index)
vec[1], (index)
vec[2], (index)
mat[1,1] (index)
mat[1,2] (index)
mat[2,1] (index)
mat[2,2] (index)

The per-chain summary is:

val, (chain, index)
vec[1], (chain, index)
vec[2], (chain, index)
mat[1,1] (chain, index)
mat[1,2] (chain, index)
mat[2,1] (chain, index)
mat[2,2] (chain, index)

So how should we parse this together? How about the naming?

ColCarroll commented 6 years ago

I have been playing around with this today, too, using the non-centered eight schools model. See here for @aloctavodia's model code for both pystan and pymc3. I am just calling trace = pm.sample() for pymc3, and

sm = pystan.StanModel(model_code=schools_code)
fit = sm.sampling(data=schools_dat, iter=1000, chains=4)

for pystan.

My biggest difficulty right now is automatically detecting that both theta and theta_tilde have the same first dimension (i.e., that each is a vector referring to the same 8 schools). Building off @aseyboldt's example notebooks, my API currently looks like this:

data = to_xarray(
    non_centered_eight_trace, 
    coords = {
        'school': np.arange(J)
    }, 
    dims={
        'theta_tilde': ['school'], 
        'theta': ['school'], 
    }
)

The output from that looks like this:

<xarray.Dataset>
Dimensions:                 (chain: 4, sample: 500, school: 8)
Coordinates:
  * school                  (school) int64 0 1 2 3 4 5 6 7
  * sample                  (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ...
  * chain                   (chain) int64 0 1 2 3
Data variables:
    mu                      (chain, sample) float64 2.431 3.441 0.8163 2.117 ...
    theta_tilde             (chain, sample, school) float64 -0.03801 -0.2972 ...
    tau                     (chain, sample) float64 10.98 2.865 5.888 0.3688 ...
    theta                   (chain, sample, school) float64 2.014 -0.8323 ...
    stat__max_energy_error  (chain, sample) float64 1.172 0.9255 -0.1971 ...
    stat__mean_tree_accept  (chain, sample) float64 0.8326 0.5738 0.9696 ...
    stat__step_size         (chain, sample) float64 0.6574 0.6574 0.6574 ...
    stat__tree_size         (chain, sample) float64 7.0 7.0 7.0 7.0 7.0 7.0 ...
    stat__energy            (chain, sample) float64 48.17 49.51 49.38 55.18 ...
    stat__tune              (chain, sample) bool False False False False ...
    stat__diverging         (chain, sample) bool False False False False ...
    stat__energy_error      (chain, sample) float64 -0.04698 0.4614 -0.1971 ...
    stat__depth             (chain, sample) int64 3 3 3 3 3 3 3 3 3 3 3 3 3 ...
    stat__step_size_bar     (chain, sample) float64 0.5544 0.5544 0.5544 ...

ColCarroll commented 6 years ago

Hrm, I am reading your and Adrian's post more carefully, and agree that it would be a better model to have multiple xarrays for the different parts of the summary.

Let me keep looking at this, but can you give an example of how your output looks on the non-centered 8 schools model?

aseyboldt commented 6 years ago

That sounds good. @ahartikainen If you'd like to share that regex code, I could put both extraction methods into a notebook and play a bit with them.

@ColCarroll I don't think we can autodetect if things are the same dimension. Just because the shape is the same doesn't mean it has the same dimension. That is why I like the explicit syntax in the model:

coords = {
    'school': ['name1', 'name2', 'name3']
}

with pm.Model(coords=coords):
    theta = pm.Whatever(dims='school')  # or dims=('school', 'whatever')

One additional issue is that of the index ('chain', 'sample'):

aseyboldt commented 6 years ago

Just a quick notebook about how I think we could use the xarrays: https://gist.github.com/aseyboldt/99b8b3ba71d0d58a92264c3bf99bbbf9

aseyboldt commented 6 years ago

It looks like altair might add support for xarray: https://github.com/altair-viz/altair/issues/891

aseyboldt commented 6 years ago

I created a (sketch of a) design document for a netcdf file format that stores traces. Ideally, I think both stan and pymc (and of course other tools) could write their traces in this format, and arviz could use an xarray representation of that for visualisation. The file format should contain all the info needed to reproduce the run, and also to debug sampling trouble.

@ahartikainen Is that similar to what you have in mind? I don't have any problems with major changes to this, it is only meant as a starting point. Feel free to edit this as you like. https://yourpart.eu/p/SXfBlllfnl

ahartikainen commented 6 years ago

@ColCarroll here is an example of the metacode above @aseyboldt the regex is in the first function

https://gist.github.com/ahartikainen/b16704eec3a912ccd3bb39d62ca04279

"Samples" / "draws" - not sure which is the correct term. I think that one draw equals one value for each parameter in the model.

@aseyboldt that looks like a good starting point.

ColCarroll commented 6 years ago

@aseyboldt I worked off your example, and also added a function that gives an informative error message since it took me a little while to understand the syntax.

https://gist.github.com/ColCarroll/c607842947b08bc44d4e1588e6bef98d

@ahartikainen Is there a way to tell pars_to_xarray that theta_axis1 and theta_tilde_axis1 are both referring to the same 8 schools? That is why I am using the slightly more cumbersome notation

data = to_xarray(non_centered_eight_trace, 
                 coords={'school': np.arange(8)}, 
                 dims={'theta_tilde': ['school'], 'theta': ['school']} 
)

Which means the resulting dimensions are just school, sample, and chain.

ahartikainen commented 6 years ago

Doing that automatically: probably not an easy task. Your way to define them looks good.

What is the array order that xarray uses? Which parts are "contiguous"? Should that be reflected in the order of the axes?

aseyboldt commented 6 years ago

@ahartikainen I think that depends a lot on the backend. If you read data from netcdf4 (hdf5 internally), it autoselects some chunking (which we can override if we want). With numpy I think it follows numpy conventions, so usually C-contiguous storage (this changes with transposition). From that point of view I think an order like (school, chain, sample) should be the fastest if you regularly look at all draws for one variable. I didn't test this, though.
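A quick NumPy check of that contiguity point (illustrative only): with C-order storage and axes ordered (school, chain, draw), selecting everything for one school yields one contiguous block of memory, while selecting one draw across all schools and chains yields a strided view:

```python
import numpy as np

# (school, chain, draw), default C order
x = np.zeros((8, 4, 500))
assert x.flags["C_CONTIGUOUS"]

# All chains and draws for one school: contiguous slice of the buffer.
one_school = x[0]
assert one_school.flags["C_CONTIGUOUS"]

# One draw across everything: strided view, skips 500 elements per step.
one_draw = x[..., 0]
assert not one_draw.flags["C_CONTIGUOUS"]
```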

@ColCarroll Errors are great :-) (or rather error messages, I don't like errors)

avehtari commented 6 years ago

Samples / Draw, not sure what is the correct term. I think that one draw equals one value for each parameter in the model.

Wikipedia https://en.wikipedia.org/wiki/Sample_(statistics) says: "In statistics and quantitative research methodology, a data sample is a set of data collected and/or selected from a statistical population by a defined procedure."

So in this case it would be natural that a posterior sample is a set of posterior draws. This is what the Stan team recommends, although it has not been strictly enforced and variation exists.

twiecki commented 6 years ago

We try to match our terminology to Stan's, so "draws" is fine with me.

ColCarroll commented 6 years ago

Wanted to post a quick update on this issue, since there was a lot of good discussion earlier:

- There are utilities now for converting posteriors from PyStan and PyMC3 to xarray Datasets
- Two out of twelve plots use xarray, and will mostly transparently work with posterior draws from either library (#111 has an example of comparing posterior draws from the eight schools model using both PyStan and PyMC3)

My view of next steps, which can go mostly in parallel:

- Finish porting plots to use xarray
- Port statistical tests to use xarray
- Add a converter for OrderedDict/dict, which would work for Edward, Edward2, and the nascent PyMC4

After that (which isn't that much!), I think it would be reasonable to cut a release on pypi/conda-forge, and start working on using, for example, sampler statistics or observed data in some of these visualizations/analyses.

I appreciate any input/suggestions/help, as always!

ahartikainen commented 6 years ago

Should we also add a kwarg for diagnostics (sampler parameters)?

It could be "all" or nothing.

Also, common names are easier if they are fixed (?)

SemanticBeeng commented 6 years ago

How would this relate to developments around Apache Arrow, which brings platform/language-independent support for big data (across C++, Python, and the JVM)?

See also https://github.com/QuantStack/xtensor/issues/394#issuecomment-330426802

It would be great to have a "common language for pymc3, pystan, and pymc4" that brings them closer to the JVM.

twiecki commented 6 years ago

@SemanticBeeng That's definitely where this is headed. I wouldn't compare it to the JVM, but rather to a common format for storing model results that's standardized across different PPLs.

SemanticBeeng commented 6 years ago

Indeed, not to the JVM specifically, but a platform-independent format - so is the JVM included?

If relevant to you, I'm curious to know how the intent in this thread:

  1. compares to Apache Arrow
  2. relates to JVM interop
  3. plans to address the need to manage data schemas across languages (asking because types are very important to JVM people). See for context:

I understand that if the JVM is not in the picture, then the above are not applicable to this context.

canyon289 commented 6 years ago

Is this now "use InferenceData throughout"?

ColCarroll commented 6 years ago

The library is now using xarray and netcdf throughout.