rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
324 stars · 17 forks

Integration with ArviZ #35

Closed · rlouf closed this issue 4 years ago

rlouf commented 4 years ago

ArviZ is the best package for the exploratory analysis of models, so it makes sense to provide seamless integration with it. I chose to implement the Trace object as a subclass of ArviZ's InferenceData so that MCX traces can be used directly in ArviZ. There is some loss of information in that translation, since the sampler returns pretty much all there is to know about the process, but that information (say, for debugging) can still be retrieved using iterative sampling.

OriolAbril commented 4 years ago

I'll try to keep an eye and help wherever possible, I should have time to review and answer questions about InferenceData.

rlouf commented 4 years ago

Thanks! I'm doing a lot of code reading and trial/error at the moment, I'll ping you when I have something tangible!

rlouf commented 4 years ago

@OriolAbril Several of the metrics in sample_stats are noticeably slowing down inference for large sample sizes. Some I can probably improve (I have to transpose a couple of arrays to fit the format), but for others like log_probability I need to perform more calculations.

This can be problematic when users are first iterating on a model and only need plot_trace. I can also think of production applications where you only need the trace and performance is critical. Why slow the program down for metrics that are not needed?

So I was wondering if you thought overriding some of the class' attributes with a @property would be reasonable in this case?
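The pattern I have in mind looks roughly like this minimal sketch (the class name, attributes, and the "lp" computation are all illustrative, not MCX's actual code): the expensive metrics are only computed the first time sample_stats is accessed.

```python
from math import fsum


class Trace:
    """Hypothetical sketch: defer the sample_stats computation until the
    attribute is first accessed, instead of paying for it during sampling."""

    def __init__(self, raw_chain):
        self._raw_chain = raw_chain
        self._sample_stats = None  # not computed yet

    @property
    def sample_stats(self):
        if self._sample_stats is None:
            # stand-in for the real (costly) metric computation
            self._sample_stats = {"lp": [fsum(state) for state in self._raw_chain]}
        return self._sample_stats
```

Users who only need plot_trace on the posterior samples would then never pay for the metrics they don't use.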

I'll first finish this draft, which precomputes everything, and implement the lazy version once I think I have a good design.

OriolAbril commented 4 years ago

Not sure I understand, you mean so that lp and other sample stats are available but only computed whenever needed?

I think we have not considered it (could be wrong though).

What we currently have for the log_likelihood (it is useful for loo/waic but otherwise requires extra memory and extra computing) group is an argument to from_pymc3, from_pyro... so that users can choose whether or not the group is to be included in the resulting InferenceData. The same could be done for sample_stats and its variables.
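That opt-in pattern can be sketched with a hypothetical from_mcx converter (the name, signature, and plain-dict return value are illustrative, not ArviZ's actual API): the caller decides whether a given group makes it into the result.

```python
def from_mcx(samples, stats, *, sample_stats=True):
    """Hypothetical converter sketch: the caller chooses whether the
    sample_stats group is included in the resulting object."""
    groups = {"posterior": samples}
    if sample_stats:
        groups["sample_stats"] = stats
    return groups
```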

rlouf commented 4 years ago

> Not sure I understand, you mean so that lp and other sample stats are available but only computed whenever needed?

Yes, I could do that by defining a sample_stats method in my Trace class that overrides the corresponding attribute in InferenceData. Does that make sense?

> I think we have not considered it (could be wrong though).

I don't necessarily think you would need to do that on your end.

> What we currently have for the log_likelihood (it is useful for loo/waic but otherwise requires extra memory and extra computing) group is an argument to from_pymc3, from_pyro... so that users can choose whether or not the group is to be included in the resulting InferenceData. The same could be done for sample_stats and its variables.

I did not see that, I will check it out!

rlouf commented 4 years ago

Never mind, I ran some simple benchmarks and it looks like the running time is only marginally affected by all the conversions. I believe I also found a way to get the log_likelihood for free during sampling by making a small change to my compiled logpdf.

Edit: it is free in terms of computation, but it does add complexity to the inference core that I am not willing to add if I don't have to. However, it is possible, and it might make sense to compute it when I compute the value of deterministic variables. I guess that's one issue with having too much freedom 🤷‍♂️ I'm happy to be working on this now, as it helps clarify both the API and parts of the internals.
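The "log_likelihood for free" idea can be illustrated with a sketch (a unit-variance normal, numpy only; none of this is MCX's actual compiled code): if the logpdf keeps the per-observation terms before summing them, the log_likelihood group is a by-product of evaluating the logpdf.

```python
import numpy as np


def loglikelihoods(data, mu):
    # per-observation log-likelihood of a unit-variance normal
    return -0.5 * (data - mu) ** 2 - 0.5 * np.log(2.0 * np.pi)


def logpdf(data, mu):
    # the total logpdf is the sum of the per-observation terms, so
    # keeping those terms around yields log_likelihood at no extra cost
    return loglikelihoods(data, mu).sum()
```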

rlouf commented 4 years ago

I can confirm that all plots / stats that take an InferenceData instance as an input and do not require log_likelihood work fine.

rlouf commented 4 years ago

Computation of log_likelihood is working; moving on to including the warmup trace. @OriolAbril How do you think I should go about storing the warmup info (like the step size during dual averaging)?

Edit: never mind, the warmup info is already available internally. I now store everything in my own format, as a namedtuple that is an attribute of Trace, and expose every element that has an ArviZ counterpart using the @property decorator.
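That design can be sketched like this (the field and property names are illustrative, not MCX's actual internals): the raw namedtuple keeps everything the sampler returned, and properties expose only the ArviZ-facing pieces.

```python
from collections import namedtuple

# everything the sampler returns, kept in the sampler's own format
RawTrace = namedtuple("RawTrace", ["samples", "warmup", "info"])


class Trace:
    def __init__(self, raw):
        self.raw = raw  # full information, e.g. for debugging

    @property
    def posterior(self):
        # only elements with an ArviZ counterpart are exposed as attributes
        return self.raw.samples

    @property
    def warmup_posterior(self):
        return self.raw.warmup
```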

rlouf commented 4 years ago

Just implemented concatenation. I will now add append and leave the prior/posterior predictive data for when I clean up their API in mcx.
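Concatenation along the chain axis might look like this minimal numpy sketch (the real Trace stores far more than a single array; the shape convention here is an assumption): in-place addition appends the other trace's chains.

```python
import numpy as np


class Trace:
    """Minimal sketch: samples stored as a (num_chains, num_samples) array."""

    def __init__(self, samples):
        self.samples = samples

    def __iadd__(self, other):
        # in-place addition concatenates the two traces' chains
        self.samples = np.concatenate([self.samples, other.samples], axis=0)
        return self
```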

rlouf commented 4 years ago

@OriolAbril if you want to have a look, the magic happens in trace.py. Everything that should work with ArviZ does, as does (in-place) addition. I just need to add the append method, improve the documentation, and refactor the code a little before merging. I'll handle prior and predictive samples in a later PR.

rlouf commented 4 years ago

Found a performance issue linked to the fact that lax.scan and the for loop do not output the chain in the same format. Reported in #41. I found a fix but will implement it in another PR.
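The mismatch can be illustrated with numpy (the shapes are assumptions for illustration): lax.scan stacks its per-iteration outputs along a new leading axis, so samples come out first, whereas a Python for loop accumulating per chain puts chains first; reconciling the two layouts means transposing, i.e. copying, the whole chain.

```python
import numpy as np

# lax.scan stacks carry outputs along a new leading (sample) axis
scan_layout = np.zeros((1000, 4, 3))          # (samples, chains, dims)

# the for-loop sampler accumulated per chain, giving chains first;
# converting between the two requires a full transpose of the chain
loop_layout = np.swapaxes(scan_layout, 0, 1)  # (chains, samples, dims)
```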