Closed ColCarroll closed 6 years ago
The PyStan fit object is a StanFit4Model Cython class. The ordered dict comes from fit.extract, but there are other ways to interact with it.
edit. typos
@ahartikainen The idea would be for PyStan to write a converter from StanFit4Model (or the ordered dict) to an xarray that can then be passed to arviz.
Should the xarray object also include diagnostics such as rhat and effective sample size? The PyStan object has a Cython function that computes rhat and effective sample size quite efficiently. The last time I was updating the effective sample size implementation in pymc3, I found that the Python implementation is much slower when the input array is large. @aseyboldt has a numba implementation which is much faster, but PyMC3 does not have numba as a dependency.
My proposal is that we make a faster implementation of rhat and effective sample size in arviz (the numba version), and a general converter that turns a pymc3 trace and a StanFit4Model into xarray (could be two independent converters). It would have the trace as the first level, and diagnostics and stats as the second level. The second level is optional but would be computed by default if not provided.
Thoughts?
Good idea. I think it makes sense to allow inclusion of these stats in the xarray and if not present, compute it using a fast (numba) implementation in arviz.
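For reference, the core of the statistic being discussed is small. Below is a minimal NumPy sketch of the classic (non-split) Gelman-Rubin rhat; the function name is my own choice, not an agreed API, and a numba `@njit` decorator could later be added for the speed-up mentioned above:

```python
import numpy as np

def gelman_rubin(chains):
    """Potential scale reduction factor (rhat), non-split variant.

    ``chains`` is an array of shape (n_chains, n_draws).
    A hypothetical sketch, not arviz's actual implementation.
    """
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    # W: mean of within-chain variances; B: between-chain variance scaled by n
    W = chains.var(axis=1, ddof=1).mean()
    B = n * chain_means.var(ddof=1)
    # Pooled posterior variance estimate
    var_hat = (n - 1) / n * W + B / n
    return np.sqrt(var_hat / W)
```

For well-mixed chains this is close to 1; chains stuck in different regions give values well above 1.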
@twiecki yes, that is a good idea. I'm currently writing a function to transform the fit object into a suitable format for the current implementation (of the plotting code).
But if ArviZ will implement code to read a specific type of xarray object, we can wait for that before we release the arviz wrapper code.
@junpenglao It would be ideal to just give a fit object to arviz and let arviz do the magic inside it.
Do you think the numba is going to be a dependency or optional?
I'm not sure computing rhat and other stats in arviz is a good idea. Some of that depends on the sampler (eg treedepth), and if we want to print warnings or a report of some kind after sampling then we need to compute it there anyway. I think it would be good if those stats could be in the xarray as well, but there is the problem of name collisions. If we only have rhat and n_effective_samples or so in the array, then that would be fine, but if we also want to put all sampler specific stats in it, then that list would suddenly be much longer and also change over time. An alternative would be to have two xarrays, one for the samples, and one for the stats.
Yes, summary stats and samples basically live in different spaces.
Either two xarrays or some tricks with the indexing could work.
I don't see a problem making numba a dependency; it has matured a lot and is well packaged. Eventually there is a high chance scipy will use it too.
Personally, I prefer to make it a dependency - the alternative implementation is too slow. I am hoping the numba installation is not an issue anymore - the last time I checked, there were some complications if you are not under conda.
I also came across this discussion on the mc-stan discourse: Proposal for consolidated output, and the related Stan wiki by @martinmodrak and @sakrejda. I think this is a great discussion to have; specifically, whether there could be a universal representation of different inference methods, with their related diagnostics.
My idea would be:
level 0, meta information. This determines the lower structure:
- used inference (sampling, approximation, estimator)
- available diagnostics
- parameterization of the approximation (VI or Laplace approximation, etc.)

level 1, summary. This would be a point from the parameter space of the posterior / likelihood function, with related error estimation for an estimator and a variance/covariance matrix for MCMC samples. This includes:
- MLE or MAP and their associated errors
- VI parameters
- mean and cov of MCMC samples

The complication is that if you are doing some kind of VI approximation that is not parameterized by only mean and std/cov, that information needs to be saved separately.

level 2, samples. MCMC samples. For VI we can sample from the approximation model (we have that functionality quite handy in PyMC3). If an estimator is used, then it is just 1 sample.

level 3, diagnostics and statistics, including:
- divergences, tree depth, etc. for HMC and NUTS
- effective sample size and rhat for MCMC samplers
- ELBO history for VI
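The four levels above could be sketched as a plain nested mapping; every key name here is hypothetical and only meant to make the proposal concrete, not an agreed schema:

```python
# Hypothetical sketch of the four-level layout proposed above.
# All names and values are illustrative.
result = {
    "meta": {                      # level 0: determines the lower structure
        "inference": "sampling",   # sampling / approximation / estimator
        "diagnostics": ["divergent", "treedepth", "rhat", "ess"],
        "parameterization": None,  # e.g. mean/cov for VI or a Laplace approximation
    },
    "summary": {                   # level 1: point estimates + uncertainty
        "mean": {"mu": 2.4, "tau": 3.1},
        "sd": {"mu": 1.1, "tau": 0.9},
    },
    "samples": {},                 # level 2: MCMC draws (or draws from a VI fit)
    "diagnostics": {},             # level 3: divergences, treedepth, ELBO history, ...
}
```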
Maybe we should start a google doc and also invite others working on PPL to edit? For example the folks from tensorflow/probability, Pyro etc.
Good idea. CC @fritzo @dustinvtran ArviZ is a package that separates out PyMC3's plotting and some analysis functionality to create many commonly used plots like a traceplot. With PyStan we're currently discussing potential standardized storage objects (xarray). Is there any interest from Edward/Edward2/Pyro to collaborate on this?
cc for active (Py)Stan folks
@ariddell @seantalts @braaannigan
I have no objections to using xarrays. Sounds like a reasonable idea.
No objections, probably handy to agree on a naming convention for parameters before implementing in the various code bases
@twiecki Yes the Pyro team is interested in collaborating on standardized storage formats that can facilitate comparison and encourage inference algorithm research. cc @neerajprad @jpchen @rohitsingh who are working on HMC and Pyro-Stan compatibility.
cc @yebai @xukai92 you might be interested in this for Turing
Re: @junpenglao, we've explicitly punted on this question for the moment while we re-organize the intermediate layer to make it possible. I agree that it's important and has been a topic of ongoing discussion, so if there is a broader discussion please let us know. Since Stan is multi-interface it might be more complicated at the file-format level (we've been talking about something streaming-friendly in ProtoBuf), but at the level of deciding how outputs should be grouped and organized, it would be fantastic to have compatibility with other projects.
As an xarray developer and probabilistic programming enthusiast, I'd really love to see this happen. Please feel free to ping me if you come across any issues.
@matthewdhoffman, @davmre, @jvdillon, @csuter, @derifatives, @axch, @srvasude
Edward2 and TFP's abstraction level doesn't really require named data structures except for dicts and namedtuples, which are more for collecting heterogeneous data. For PyMC*, Stan, and others, xarray instead of custom classes or pd.DataFrame sounds like a great idea.
Wow awesome proposal guys!
If any of the pymc folks want to try it out in some projects, there is some code for getting pymc traces into xarray here: https://discourse.pymc.io/t/use-xarray-for-traces/73
For some time now, I've been wrapping pymc traces in a fit object that serialises to a netcdf file. The format for that file could be something like this:

/trace: Stores the actual trace of an mcmc run, eg to_xarray(pymc_trace).to_netcdf('file', group='/trace')
/trace_stats: This is where stats of the sampler (treedepth etc.) can go. Also probably info about where divergences happened. Also rhat and effective_n etc.
/data: The observed variables and their values in the model (optional)
/advi: Some format for the advi result (not sure about that, could just be a trace as well, or some other representation of the approximation)
/advi_stats: Stats for the advi, eg the history of the elbo or so.

And we can put some meta info in the attr tags as well. If we had a serialization format for the model itself, we could also add a /pymc_model group and store it there. (@stan-folks /stan_model :-) )

This should avoid name collisions between stats and variables, but it would probably mean that we have to duplicate some dimension labels if we need them in more than one group (not sure how to get xarray to read dimensions if they are in different groups, but if that works we could just add /dims).
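As a sketch of that layout: each group could be built as a separate xarray Dataset and written with to_netcdf. This assumes xarray and numpy are available; the group names follow the proposal above, the data is dummy data, and actually writing groups to one file would need a netCDF4-capable backend:

```python
import numpy as np
import xarray as xr

chain, sample = np.arange(4), np.arange(500)
# One Dataset per proposed netCDF group.  Each could be written with
# ds.to_netcdf('run.nc', group=name, mode='a') given a netCDF4 backend.
groups = {
    '/trace': xr.Dataset(
        {'mu': (('chain', 'sample'), np.random.randn(4, 500))},
        coords={'chain': chain, 'sample': sample},
    ),
    '/trace_stats': xr.Dataset(
        {'treedepth': (('chain', 'sample'),
                       np.random.randint(1, 10, (4, 500)))},
        coords={'chain': chain, 'sample': sample},
    ),
}
# Meta info (creating program and version) goes into attrs, as suggested
# above; the value here is illustrative.
groups['/trace'].attrs['created_by'] = 'pymc3 3.4.1'
```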
On the topic of interoperability:
I think it would be great if we could get this to a point where different tools use the same format. But I think we also need to be careful not to promise too much interoperability. I can't see a reason why /trace couldn't be the same for eg stan and pymc, but for /trace_stats this isn't as clear anymore. They store basically the same thing, but if we want to keep that interoperable it might hinder development. An attr that stores which program and version created those stats might be really helpful, and a visualisation lib could then just have specific code to read those if necessary.
Hey, @aseyboldt @ColCarroll
I have basically put stuff into xarray from the PyStan fit object. This should work with PyStan 2.16 onwards (I updated our .extract method). It can still be updated for earlier versions if needed.
I did split the data between the warmup and sampling.
def pars_to_xarray(fit, pars=None, infer_dtypes=True):
    ...  # regex magic added to infer ints from the model code automatically
    return data_set, data_set_warmup

def sampler_params_to_xarray(fit, params=None):
    ...
    return sampler_params_dataset, sampler_params_dataset_warmup

def inits_to_xarray(fit, pars=None, infer_dtypes=True):
    ...
    return inits_dataset

def summary_to_xarray(fit, pars=None):
    ...  # transform summary data to dataframe --> xarray
    return summary_dataset, c_summary_dataset
All the sampled parameters are in their "original" shape (same goes for the inits function)
shape = (draw, chain, *parameter_shape)
val, vec, mat -->
val, (draw, chain)
vec, (draw, chain, vec_axis1)
mat, (draw, chain, mat_axis1, mat_axis2)
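The shape convention above can be illustrated with plain numpy arrays (the concrete sizes here are made up):

```python
import numpy as np

draws, chains = 500, 4
# A scalar, a length-3 vector and a 2x2 matrix parameter, each stored
# with the leading (draw, chain) axes described above.
val = np.empty((draws, chains))
vec = np.empty((draws, chains, 3))
mat = np.empty((draws, chains, 2, 2))

# Selecting one draw of one chain recovers the parameter's own shape.
assert mat[0, 0].shape == (2, 2)
```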
All the sampler parameters are
accept_stat__, (draw, chain) float64
stepsize__, (draw, chain) float64
treedepth__, (draw, chain) int64
n_leapfrog__, (draw, chain) int64
divergent__, (draw, chain) bool
energy__, (draw, chain) float64
lp__, (draw, chain) float64
The summary xarray Dataset is in flatname format:
val, (index)
vec[1], (index)
vec[2], (index)
mat[1,1] (index)
mat[1,2] (index)
mat[2,1] (index)
mat[2,2] (index)
The per-chain summary is:
val, (chain, index)
vec[1], (chain, index)
vec[2], (chain, index)
mat[1,1] (chain, index)
mat[1,2] (chain, index)
mat[2,1] (chain, index)
mat[2,2] (chain, index)
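The flatname expansion shown above can be generated from a parameter's shape. A small sketch (the function name is hypothetical; the ordering follows the listing above, with the last index varying fastest, which may differ from Stan's own column-major flatname order):

```python
from itertools import product

def flatnames(name, shape):
    """Expand a parameter into flat names like 'mat[1,1]', 'mat[1,2]', ...

    Indices are 1-based, matching the listing above.  A scalar (empty
    shape) keeps its bare name.
    """
    if not shape:
        return [name]
    return [f"{name}[{','.join(str(i + 1) for i in idx)}]"
            for idx in product(*(range(s) for s in shape))]
```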
So how should we parse this together? How about the naming?
I have been playing around with this today, too, using the non-centered eight schools model. See here for @aloctavodia's model code for both pystan and pymc3. I am just calling trace = pm.sample() for pymc3, and

sm = pystan.StanModel(model_code=schools_code)
fit = sm.sampling(data=schools_dat, iter=1000, chains=4)

for pystan.
My biggest difficulty right now is automatically detecting that both theta and theta_tilde have the same first dimension (i.e., that each is a vector referring to the same 8 schools). Building off @aseyboldt's example notebooks, my API currently looks like this:
data = to_xarray(
non_centered_eight_trace,
coords = {
'school': np.arange(J)
},
dims={
'theta_tilde': ['school'],
'theta': ['school'],
}
)
The output from that looks like this:
<xarray.Dataset>
Dimensions: (chain: 4, sample: 500, school: 8)
Coordinates:
* school (school) int64 0 1 2 3 4 5 6 7
* sample (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ...
* chain (chain) int64 0 1 2 3
Data variables:
mu (chain, sample) float64 2.431 3.441 0.8163 2.117 ...
theta_tilde (chain, sample, school) float64 -0.03801 -0.2972 ...
tau (chain, sample) float64 10.98 2.865 5.888 0.3688 ...
theta (chain, sample, school) float64 2.014 -0.8323 ...
stat__max_energy_error (chain, sample) float64 1.172 0.9255 -0.1971 ...
stat__mean_tree_accept (chain, sample) float64 0.8326 0.5738 0.9696 ...
stat__step_size (chain, sample) float64 0.6574 0.6574 0.6574 ...
stat__tree_size (chain, sample) float64 7.0 7.0 7.0 7.0 7.0 7.0 ...
stat__energy (chain, sample) float64 48.17 49.51 49.38 55.18 ...
stat__tune (chain, sample) bool False False False False ...
stat__diverging (chain, sample) bool False False False False ...
stat__energy_error (chain, sample) float64 -0.04698 0.4614 -0.1971 ...
stat__depth (chain, sample) int64 3 3 3 3 3 3 3 3 3 3 3 3 3 ...
stat__step_size_bar (chain, sample) float64 0.5544 0.5544 0.5544 ...
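A converter with the signature used above could look roughly like this. This is a sketch under the assumption that the trace is a mapping from variable names to arrays of shape (chain, sample, *var_shape); it is not the actual arviz implementation:

```python
import numpy as np
import xarray as xr

def to_xarray(trace, coords=None, dims=None):
    """Hypothetical sketch of the converter called above.

    ``trace`` maps variable names to arrays of shape
    (chain, sample, *var_shape).  ``dims`` maps a variable name to the
    names of its trailing dimensions, which are looked up in ``coords``;
    unnamed trailing dimensions get autogenerated names.
    """
    coords = dict(coords or {})
    dims = dims or {}
    data_vars = {}
    for name, values in trace.items():
        extra = dims.get(name, [f'{name}_dim_{i}'
                                for i in range(values.ndim - 2)])
        data_vars[name] = (['chain', 'sample'] + list(extra), values)
    return xr.Dataset(data_vars, coords=coords)

data = to_xarray(
    {'mu': np.random.randn(4, 500), 'theta': np.random.randn(4, 500, 8)},
    coords={'school': np.arange(8)},
    dims={'theta': ['school']},
)
```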
Hrm, I am reading your and Adrian's posts more carefully, and agree that it would be a better model to have multiple xarrays for the different parts of the summary.
Let me keep looking at this, but can you give an example of how your output looks on the non-centered 8 schools model?
That sounds good. @ahartikainen If you like to share that regex code I could put both extraction methods into a notebook and play a bit with it.
@ColCarroll I don't think we can autodetect whether things share a dimension. Just because the shape is the same doesn't mean it is the same dimension. That is why I like the explicit syntax in the model:
coords = {
'school': ['name1', 'name2', 'name3']
}
with pm.Model(coords=coords):
theta = pm.Whatever(dims='school') # or dims=('school', 'whatever')
One additional issue is that of the index ('chain', 'sample'): should it be called sample or draw? Using trace.isel(sample=100) we get one value per chain. I think that is a bit counterintuitive. We could use draw (or sample) as a hierarchical index, trace.stack(draw=('chain', 'sample')), so that trace.isel(draw=100) gives us a single sample. I kind of like this, but I'm not sure if the additional complexity of using a hierarchical index in all traces is worth it.

Just a quick notebook about how I think we could use the xarrays: https://gist.github.com/aseyboldt/99b8b3ba71d0d58a92264c3bf99bbbf9
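The stacking idea can be tried directly in xarray; a small sketch with dummy data, assuming a Dataset indexed by (chain, sample):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({'mu': (('chain', 'sample'), np.random.randn(4, 500))})

# Build the hierarchical index discussed above: one 'draw' index that
# runs over every (chain, sample) pair.
stacked = ds.stack(draw=('chain', 'sample'))

# Now selecting one position along 'draw' yields a single value,
# not one value per chain.
single = stacked.isel(draw=100)
```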
It looks like altair might add support for xarray: https://github.com/altair-viz/altair/issues/891
I created a (sketch of a) design document for a netcdf file format that stores traces. Ideally, I think both stan and pymc (and of course other tools) could write their traces in this format, and arviz could use an xarray representation of it for visualisation. The file format should contain all the info needed to reproduce the run, and also to debug sampling trouble.
@ahartikainen Is that similar to what you have in mind? I don't have any problems with major changes to this, it is only meant as a starting point. Feel free to edit this as you like. https://yourpart.eu/p/SXfBlllfnl
@ColCarroll here is an example of the metacode above @aseyboldt the regex is in the first function
https://gist.github.com/ahartikainen/b16704eec3a912ccd3bb39d62ca04279
Samples / Draw, not sure what is the correct term. I think that one draw equals one value for each parameter in the model.
@aseyboldt that looks like a good starting point.
@aseyboldt I worked off your example, and also added a function that gives an informative error message since it took me a little while to understand the syntax.
https://gist.github.com/ColCarroll/c607842947b08bc44d4e1588e6bef98d
@ahartikainen Is there a way to tell pars_to_xarray that theta_axis1 and theta_tilde_axis1 both refer to the same 8 schools? That is why I am using the slightly more cumbersome notation

data = to_xarray(non_centered_eight_trace,
                 coords={'school': np.arange(8)},
                 dims={'theta_tilde': ['school'], 'theta': ['school']}
                 )

which means the resulting dimensions are just school, sample, and chain.
Doing that automatically: probably not an easy task. Your way to define them looks good.
What is the array order that xarray uses? Which parts are "contiguous"? Should that be reflected in the order of the axes?
@ahartikainen I think that depends a lot on the backend. If you read data from netcdf4 (hdf5 internally), it autoselects some chunking (which we can override if we want). With numpy I think it uses numpy conventions, so usually (unless transposed) C-contiguous storage. From that point of view, I think an order like (school, chain, sample) should be the fastest if you regularly look at all draws for one variable. I didn't test this, though.
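The strides of the two layouts illustrate the point; this is about numpy's default C order (last axis contiguous), not a claim about any particular file backend:

```python
import numpy as np

# Two layouts for the same data; in C order the *last* axis is contiguous.
a = np.zeros((8, 4, 500))   # (school, chain, sample)
b = np.zeros((4, 500, 8))   # (chain, sample, school)

# Reading all draws of one variable for one school and chain:
# in layout `a` those 500 float64 values sit 8 bytes apart (adjacent),
# in layout `b` they are 64 bytes apart (8 elements between them).
assert a[0, 0, :].strides == (8,)
assert b[0, :, 0].strides == (64,)
```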
@ColCarroll Errors are great :-) (or rather error messages, I don't like errors)
Samples / Draw, not sure what is the correct term. I think that one draw equals one value for each parameter in the model.
Wikipedia https://en.wikipedia.org/wiki/Sample_(statistics) says: "In statistics and quantitative research methodology, a data sample is a set of data collected and/or selected from a statistical population by a defined procedure."
So in this case it would be natural that a posterior sample is a set of posterior draws. This is what the Stan team recommends, although it has not been strictly enforced and variation exists.
We try to match our terminology to Stan so draws is fine with me.
Wanted to post a quick update on this issue, since there was a lot of good discussion earlier:
-- There are utilities now for converting posteriors from PyStan and PyMC3 to xarray Datasets
-- Two out of twelve plots use xarray, and will mostly transparently work with posterior draws from either library (#111 has an example of comparing posterior draws from the eight schools model using both PyStan and PyMC3)
My view of next steps, which can go mostly in parallel:
-- Finish porting plots to use xarray
-- Port statistical tests to use xarray
-- Add a converter for OrderedDict/dict, which would work for Edward, Edward2, and the nascent PyMC4
After that (which isn't that much!), I think it would be reasonable to cut a release on pypi/conda-forge, and start working on using, for example, sampler statistics or observed data in some of these visualizations/analyses.
I appreciate any input/suggestions/help, as always!
Should we also add a kwarg for diagnostics (sampler parameters)? It could be "all" or nothing.
Also, common names are easier if they are fixed (?)
How would this relate to developments around apache arrow, which brings platform/language-independent support for big data (across C++, Python and the JVM)?
See also https://github.com/QuantStack/xtensor/issues/394#issuecomment-330426802
It would be great to have a "common language for pymc3, pystan, and pymc4" that also brings them closer to the JVM.
@SemanticBeeng That's definitely where this is headed. I wouldn't compare it to JVM but rather a common format to store model results that's standardized across different PPLs.
Indeed, not the JVM specifically, but a platform-independent format. So is the JVM included?
If relevant to you, I am curious how the intent in this thread relates to the apache arrow data schema across languages (asking because for JVM people types are very important).
See for context:
I understand that if the JVM is not in the picture, then the above is not applicable to this context.
Does this now use InferenceData throughout?
The library is now using xarray and netcdf throughout.
There have been proposals to use xarray as a common language for pymc3, pystan, and pymc4. This library might be a good place to start that by implementing utility functions for translating pymc3's Multitrace and pystan's OrderedDict into xarray objects, and then having all plotting functions work with xarrays.