TuringLang / MCMCDiagnosticTools.jl

https://turinglang.org/MCMCDiagnosticTools.jl/dev

Support different dimension permutations #5

Open sethaxen opened 3 years ago

sethaxen commented 3 years ago

The functions currently assume draws are in a single array with shape (draws, chains, params), like MCMCChains.Chains stores. We should consider how to support different permutations (could be as simple as recommending users use PermutedDimsArray). E.g. ArviZ defaults to (chains, draws, params). Not certain about Soss's SampleChains.

devmotion commented 3 years ago

> The functions currently assume draws are in a single array with shape (draws, chains, params)

They either work with vectors of draws or arrays of shape (draws, params, chains).
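
A minimal sketch of the "users permute their arrays" option, assuming an ArviZ-style input of shape (chains, draws, params) and the (draws, params, chains) order mentioned above. The array and its sizes are made up for illustration; `PermutedDimsArray` and `permutedims` are from Julia Base.

```julia
# Hypothetical ArviZ-style array: (chains, draws, params) = (4, 100, 3).
arviz_style = reshape(collect(1.0:1200.0), 4, 100, 3)

# Lazy wrapper in (draws, params, chains) order; no data is copied.
# perm[i] is the source dimension that becomes dimension i of the result.
lazy = PermutedDimsArray(arviz_style, (2, 3, 1))

# Eager copy in the same order, stored contiguously in the new layout.
eager = permutedims(arviz_style, (2, 3, 1))

size(lazy)                                # (100, 3, 4)
lazy[7, 2, 3] == arviz_style[3, 7, 2]     # true
```

The lazy wrapper avoids a copy but keeps the original memory order, so a function that loops heavily over draws may run faster on the eager copy.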

sethaxen commented 3 years ago

How would one pass a vector of draws to e.g. ess_rhat?

devmotion commented 3 years ago

No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

ParadaCarleton commented 3 years ago

> The functions currently assume draws are in a single array with shape (draws, chains, params), like MCMCChains.Chains stores. We should consider how to support different permutations (could be as simple as recommending users use PermutedDimsArray). E.g. ArviZ defaults to (chains, draws, params). Not certain about Soss's SampleChains.

And I think Stan defaults to (params, chains, draws), giving us a head start on our apparent task of going through every possible permutation of indices.

devmotion commented 3 years ago

I don't think this package should support different permutations. IMO we need a clearly documented and consistent choice for such situations but then users have to permute the arrays if their data is in a different format.

And just to reiterate, not all statistics work with 3d arrays of samples. Some work just with a vector of scalar-valued samples (one parameter, one chain) and I don't think this should be changed. Also the Rstar statistic works with a matrix of samples of size (draws, params) and a corresponding vector of chain indices (this is more general than a 3d array).
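
As a rough illustration of the matrix-plus-chain-indices format described above, here is one way a 3d array could be flattened into it. `flatten_chains` is a hypothetical helper, not part of this package, and it assumes the input is laid out as (draws, params, chains).

```julia
# Hypothetical helper, not part of MCMCDiagnosticTools.jl.
# Assumes `samples` has shape (draws, params, chains).
function flatten_chains(samples::AbstractArray{<:Any,3})
    ndraws, nparams, nchains = size(samples)
    # Move chains next to draws, then merge both into one row dimension.
    stacked = reshape(permutedims(samples, (1, 3, 2)), ndraws * nchains, nparams)
    # Row (c - 1) * ndraws + d holds draw d of chain c.
    chain_indices = repeat(1:nchains; inner=ndraws)
    return stacked, chain_indices
end

samples = rand(100, 3, 4)
x, chain_idx = flatten_chains(samples)

size(x)                                      # (400, 3)
x[chain_idx .== 2, :] == samples[:, :, 2]    # true
```

The reverse direction is not always possible: a matrix with chain indices can hold chains of different lengths, which is why it is the more general format.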

devmotion commented 3 years ago

I should add that I don't think we have to stick with the current convention but it would be good to be consistent for 3d arrays. I guess the choice should be motivated by what is the most convenient and efficient layout. Since Julia uses column-major order and Python row-major, probably it differs from what one would choose in Python.
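
To make the column-major point concrete, a small sketch using `strides` from Julia Base. The (draws, params, chains) array and its sizes are assumptions chosen for illustration.

```julia
# Hypothetical (draws, params, chains) array.
a = zeros(100, 3, 4)

# Column-major storage: the first dimension varies fastest in memory.
strides(a)                      # (1, 100, 300)

# All draws of one parameter in one chain sit next to each other...
strides(view(a, :, 2, 3))       # (1,)

# ...while the parameters of a single draw are 100 elements apart.
strides(view(a, 5, :, 3))       # (100,)
```

In a row-major library such as NumPy the last dimension is the contiguous one, so the same logical layout has the opposite memory behaviour there.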

ParadaCarleton commented 3 years ago

> I should add that I don't think we have to stick with the current convention but it would be good to be consistent for 3d arrays. I guess the choice should be motivated by what is the most convenient and efficient layout. Since Julia uses column-major order and Python row-major, probably it differs from what one would choose in Python.

I mean, being able to work with different permutations of indices comes as a free side effect of using something like AxisKeys.jl, which I think we should be using anyways to avoid bugs from messing up the order we're indexing in. We can always provide a free "Fallback" that assumes dimensions are ordered in some clearly specified way.

devmotion commented 3 years ago

Please no, let's just stick with generic AbstractArrays - one main motivation here is to get rid of the Chains/AxisArrays mess and just be as generic as possible.

ParadaCarleton commented 3 years ago

I think the best arrangement would be (parameters, draws, chains). Operations on chains are usually done in parallel, e.g. sampling a different chain on each core, so it's not necessary to have chains located close to each other in memory. Parameters are sampled together, so they should be located close in memory so that every time one parameter is written to memory, the next parameter can be written to memory pretty easily. Iterations should be located somewhat close to each other, but aren't always accessed together the same way that parameters from a single iteration usually are. Please correct me if I'm wrong, I'm not a computer scientist.
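
A sketch of the write pattern this argument has in mind, assuming a preallocated (params, draws, chains) array. `fake_draw` is a stand-in for a real sampler step and is purely hypothetical.

```julia
nparams, ndraws, nchains = 3, 1000, 4
samples = Array{Float64}(undef, nparams, ndraws, nchains)

# Stand-in for a sampler step: returns one draw of all parameters.
fake_draw(n) = randn(n)

for c in 1:nchains              # chains could instead run on separate tasks
    for d in 1:ndraws
        # Each draw fills one contiguous column of length nparams.
        samples[:, d, c] .= fake_draw(nparams)
    end
end

strides(samples)                # (1, 3, 3000)
```

With this layout, consecutive draws of a chain are adjacent in memory and each chain occupies its own block of `nparams * ndraws` elements.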

ParadaCarleton commented 3 years ago

> No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

I believe ess_rhat currently uses (draws, params, chains). (An arrangement I find extremely counterintuitive, since chains and draws aren't together.)

devmotion commented 3 years ago

> > No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

> I believe ess_rhat currently uses (draws, params, chains). (An arrangement I find extremely counterintuitive, since chains and draws aren't together.)

Ah yes, this is what I wanted to write and what is mentioned in the documentation. The reason for this layout is that it is the one used in MCMCChains.Chains. However, I don't know what motivated the design choice there, this was decided before I got involved in MCMCChains. Maybe @cpfiffer knows why it was preferred over (parameters, draws, chains)?

cpfiffer commented 3 years ago

I think this was just a holdover from Mamba -- it was really something I hadn't considered.

sethaxen commented 3 years ago

There are two different ways to think about this. One is to reason about what permutation users are likely to pass (which leads to reasoning about what a sensible ordering in memory would be). However, a given PPL may not deliver that ordering (e.g. MCMCChains and IIRC SampleChains). The other is to think about the permutation that would be most efficient for a given function. However, different functions might prefer different permutations, and we should be consistent.

Ideally these two would converge, but I don't know if they do.

ParadaCarleton commented 3 years ago

I think any differences in terms of actual speed/efficiency are probably pretty small -- writing MCMC samples to memory is not going to be the bottleneck for something like HMC under any reasonable set of circumstances. Because of that, I think choices of index orderings should probably be based on what users are most likely to consider intuitive/reasonable orderings. In that case, I think either of (params, draws, chains) or (chains, draws, params) are the most intuitive, since they involve going from more general to more specific, or more specific to more general. (Multiple parameters are contained in a single draw, and several draws are contained in a single chain.) The former has the advantage that it's easier to leave off indices for chains when users are only sampling from a single chain, so I propose we go with that unless anyone has any strong objections.
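
A small sketch of why a trailing chains dimension is easy to leave off in Julia, assuming a single chain stored as a (params, draws) matrix. The sizes are illustrative.

```julia
single_chain = rand(3, 100)     # (params, draws), one chain

# View the matrix as (params, draws, chains) with one chain; this shares memory.
as_3d = reshape(single_chain, size(single_chain)..., 1)
size(as_3d)                                  # (3, 100, 1)

# Trailing indices may be omitted when those dimensions have length 1,
# and extra trailing indices equal to 1 are also accepted.
as_3d[2, 5] == single_chain[2, 5]            # true
single_chain[2, 5, 1] == as_3d[2, 5, 1]      # true
```

The same trick does not work for a leading dimension, which is the practical difference between (params, draws, chains) and (chains, draws, params) for single-chain users.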

ParadaCarleton commented 3 years ago

Does anybody object, or should I make a pull request reordering axes like this? I've written some PSIS code that assumes this ordering for ess_rhat, and would like to know whether I should rearrange the PSIS code or the ordering for ess_rhat.

ParadaCarleton commented 3 years ago

@sethaxen @devmotion I can create a PR reordering these axes, and implementing a consistent interface that works with arrays following this standard.

In cases where the interface isn't consistent, there's usually a fix that will make the function easier to work with from a user perspective. For instance, if a function only accepts one chain, then calling it on an array of multiple chains should return a vector with the results of applying it to each individual chain. (We can, of course, keep the original function for users who want to only call it on a single chain.) The goal should be to provide a polished, ArviZ-like interface that "Just works," and lets the user pass a single object to each function, rather than having to figure out how they need to permute, cut up, or use eachslice on their arrays to get the diagnostic they're looking for. This has already been done for a handful of functions, but not all of them.
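
One possible shape for such a convenience method, sketched under assumptions: `lag1_autocor` is a hypothetical single-chain diagnostic invented for this example (it is not a function of this package), and the 3d method assumes the (draws, params, chains) layout.

```julia
using Statistics

# Hypothetical single-chain diagnostic: lag-1 autocorrelation of one parameter.
function lag1_autocor(x::AbstractVector{<:Real})
    m = mean(x)
    num = sum((x[i] - m) * (x[i + 1] - m) for i in 1:(length(x) - 1))
    return num / sum(abs2, x .- m)
end

# Hypothetical convenience method for a (draws, params, chains) array.
# Returns a (params, chains) matrix: one value per parameter and chain.
function lag1_autocor(samples::AbstractArray{<:Real,3})
    return [lag1_autocor(view(samples, :, p, c))
            for p in axes(samples, 2), c in axes(samples, 3)]
end

samples = randn(200, 3, 4)
result = lag1_autocor(samples)
size(result)                    # (3, 4)
```

The original vector method stays available, so callers with other layouts can still apply it to whatever slices they have.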

devmotion commented 3 years ago

In my opinion, we should not "unify" functions that operate on single chains by moving them to an interface that works on 3d arrays of multiple chains with multiple parameters. It just makes it more complicated for downstream packages that use a different layout to use these diagnostics (e.g. vectors of chains of possibly different length, chains based on StructArrays etc. - in particular the StructArray approach is a longstanding issue/idea that we want to explore as an alternative to MCMCChains.Chains) and I don't see any immediate advantages even for the 3d case - you can always apply the function on the different slices of the array and e.g. even in this case sometimes you might want to pool different chains.
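
To illustrate the point about other layouts, a sketch with chains of different lengths, using `var` from the Statistics standard library as a stand-in for any function of one vector of draws. All data here is made up.

```julia
using Statistics

# Chains of different lengths do not fit in a 3d array,
# but a function of one vector of draws handles them directly.
chains = [randn(100), randn(250), randn(180)]

per_chain = map(var, chains)            # one estimate per chain
pooled = var(reduce(vcat, chains))      # one estimate from all draws pooled

# The same function covers a (draws, params, chains) array by slicing;
# here, parameter 2 in each chain.
samples = randn(100, 3, 4)
per_chain_3d = [var(view(samples, :, 2, c)) for c in axes(samples, 3)]
```

Whether to slice or pool is left to the caller, which is the flexibility a fixed 3d interface would hide.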

In general, I think any changes of the input layouts or dimensions should be left for a 0.2 or even more distant release but not included in the initial 0.1 version since otherwise it will just be more difficult to replace the diagnostics in MCMCChains with this package and the transition will be less clean (it would require many additional changes in https://github.com/TuringLang/MCMCChains.jl/pull/310). So unfortunately currently I would not approve any such PR that changes the input layout of any diagnostics but I think we should consider if we want to change the default permutation of 3d arrays at a later stage.

ParadaCarleton commented 3 years ago

Sure, we can handle this later if you want.

As for the other thing, I'm not suggesting that we replace the existing methods or get rid of them, just that we provide additional methods that handle everything for users who use the "default" arrangement, which I expect would be a three-dimensional array. Not every method needs to be more complicated, but every method should accept a 3d array as input (assuming it makes any kind of sense for it to accept that array). If someone wants to work with a matrix and a vector of chain indices, they can use the methods we already have without being bothered by the fact that we have another one that works on arrays. On the other hand, users who already have their data stored in an array shouldn't have to spend even more time cleaning their data than they already do. Figuring out how to e.g. disassemble an array and convert it into a matrix representation with a bunch of chain indices is going to be a pretty big waste of time for users; why not just have it work out of the box for them?