Closed — rlouf closed this issue 4 years ago
I'll try to keep an eye and help wherever possible, I should have time to review and answer questions about InferenceData.
Thanks! I'm doing a lot of code reading and trial/error at the moment, I'll ping you when I have something tangible!
@OriolAbril Several of the metrics in `sample_stats` are noticeably slowing down inference for large sample sizes. Some I can probably improve (I have to transpose a couple of arrays to fit the format), but for others, like `log_probability`, I need to perform more calculations.
It can be problematic when users are first iterating on a model and only need `plot_trace`. I can also think of production applications where you would only need the trace and performance is critical. Why slow the program down computing metrics that are not necessary?
So I was wondering whether you think overriding some of the class's attributes with a `@property` would be reasonable in this case?
I'll first finish this draft precomputing everything, and will implement the lazy version once I have a good design.
Not sure I understand, you mean so that `lp` and other sample stats are available but only computed whenever needed?
I think we have not considered it (could be wrong though).
What we currently have for the `log_likelihood` group (it is useful for loo/waic but otherwise requires extra memory and extra computation) is an argument to `from_pymc3`, `from_pyro`... so that users can choose whether or not the group is to be included in the resulting InferenceData. The same could be done for `sample_stats` and its variables.
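The opt-out pattern described here could be sketched roughly as follows. This is an illustrative converter with made-up names (`to_inference_data`, `pointwise_loglik`), not the actual ArviZ signature; the point is only that the expensive group is never materialized unless the user asks for it:

```python
def to_inference_data(posterior, pointwise_loglik, log_likelihood=True):
    """Hypothetical converter: the `log_likelihood` flag lets users skip
    building the group entirely, saving memory and computation.

    `pointwise_loglik` is passed as a callable so that skipping the group
    costs nothing at all.
    """
    groups = {"posterior": posterior, "sample_stats": {}}
    if log_likelihood:
        # Only materialize the expensive group when requested.
        groups["log_likelihood"] = {"obs": pointwise_loglik()}
    return groups

# Skipping the group: the callable is never invoked.
idata = to_inference_data(
    {"mu": [0.1, 0.2]}, lambda: [[-1.2, -0.8]], log_likelihood=False
)
```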
> Not sure I understand, you mean so that `lp` and other sample stats are available but only computed whenever needed?
Yes, I could do that by defining a `sample_stats` method in my `Trace` class that overrides the corresponding attribute in `InferenceData`. Do I make sense?
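The lazy-attribute idea can be sketched with a plain `@property` and a cache. The names below (`raw_samples`, `_compute_log_prob`) are hypothetical, not the actual MCX API:

```python
import numpy as np

class Trace:
    """Sketch: compute an expensive sample stat lazily, on first access."""

    def __init__(self, raw_samples):
        self.raw_samples = raw_samples
        self._log_prob = None  # cache; None means "not computed yet"

    @property
    def log_prob(self):
        # Only pay the computation cost the first time the attribute is read.
        if self._log_prob is None:
            self._log_prob = self._compute_log_prob()
        return self._log_prob

    def _compute_log_prob(self):
        # Placeholder for the expensive per-sample computation.
        return np.sum(self.raw_samples, axis=-1)
```

Users who only call `plot_trace` never trigger the computation; anything that reads `trace.log_prob` pays the cost exactly once.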
> I think we have not considered it (could be wrong though).
I don't necessarily think you would need to do that on your end.
> What we currently have for the `log_likelihood` group (it is useful for loo/waic but otherwise requires extra memory and extra computation) is an argument to `from_pymc3`, `from_pyro`... so that users can choose whether or not the group is to be included in the resulting InferenceData. The same could be done for `sample_stats` and its variables.
I did not see that, I will check it out!
Nevermind, I ran some simple benchmarks and it looks like the running time is only marginally affected by all the conversions. I believe I also found a way to get the `log_likelihood` for free during sampling by making a small change to my compiled logpdf.
Edit: It is free in terms of computation, but it does add complexity to the inference core that I am not willing to add if I don't have to. However, it is possible, and it might make sense to compute it when I compute the values of deterministic variables. I guess that's one issue with having too much freedom 🤷‍♂️ I'm happy to be working on this now, it helps clarify both the API and parts of the internals.
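The "log_likelihood for free" idea can be illustrated with a toy logpdf. This is a hypothetical sketch (a unit-variance normal, names invented), not the actual mcx compiled logpdf: the sampler needs the total log-density anyway, so returning the pointwise terms alongside it costs essentially nothing extra.

```python
import numpy as np

def make_logpdf(data):
    """Hypothetical: build a logpdf that also returns the pointwise
    log-likelihood, so it can be recorded during sampling for free."""

    def logpdf(mu):
        # Per-observation log-density of a unit-variance normal (toy model).
        pointwise = -0.5 * (data - mu) ** 2 - 0.5 * np.log(2 * np.pi)
        logprob = np.sum(pointwise)  # the value the sampler actually uses
        return logprob, pointwise    # pointwise terms come along "for free"

    return logpdf
```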
I can confirm that all plots/stats that take an `InferenceData` instance as an input and do not require `log_likelihood` work fine.
Computation of `log_likelihood` is working; moving on to including the warmup trace. @OriolAbril, how do you think I should go about storing the warmup info (like the step size during dual averaging)?
Edit: never mind, it is available internally. I am now storing everything in my format as a namedtuple that is an attribute of `Trace`. I expose every element that has a counterpart in ArviZ using the `@property` decorator.
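That storage scheme might look something like the sketch below. The field names (`RawTrace`, `samples`, `log_prob`, `step_size`) are illustrative, not the real mcx internals: the full sampler output lives in a namedtuple, and only the elements with an ArviZ counterpart are exposed as properties.

```python
from collections import namedtuple

# Hypothetical container: everything the sampler returns, kept in one place.
RawTrace = namedtuple("RawTrace", ["samples", "log_prob", "step_size"])

class Trace:
    def __init__(self, raw):
        self._raw = raw  # full sampler output, kept private

    @property
    def posterior(self):
        # Exposed under the name ArviZ expects for the posterior group.
        return self._raw.samples

    @property
    def sample_stats(self):
        # Only the stats ArviZ knows about; the rest stays in `_raw`.
        return {"lp": self._raw.log_prob, "step_size": self._raw.step_size}
```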
Just implemented concatenation. I will now add `append` and leave the prior/posterior predictive data for when I clean up their API in mcx.
@OriolAbril if you want to have a look, the magic happens in trace.py. Everything that should work with ArviZ does, and so does the (in-place) addition. I just need to add the `append` method, improve the documentation, and refactor the code a little before merging. I'll handle prior and predictive samples in a later PR.
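The in-place addition could work roughly like this. A minimal sketch, assuming samples are stored as `(chain, draw)` arrays keyed by variable name (not the real mcx implementation):

```python
import numpy as np

class Trace:
    """Sketch of in-place trace addition via `__iadd__`."""

    def __init__(self, samples):
        self.samples = samples  # dict: variable name -> ndarray (chain, draw)

    def __iadd__(self, other):
        # Append the new draws along the draw axis, variable by variable.
        for name, array in other.samples.items():
            self.samples[name] = np.concatenate(
                [self.samples[name], array], axis=1
            )
        return self
```

With this in place, `trace += more_samples` extends each chain with the newly drawn samples.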
Found a performance issue linked to the fact that `lax.scan` and the `for` loop do not output the chain in the same format. Reported in #41. I found a fix, but will implement it in another PR.
ArviZ is the best package for the exploratory analysis of models, so it makes sense to provide a seamless integration with the library. I chose to implement the `Trace` object as a subclass of ArviZ's `InferenceData` so MCX traces can be used directly in ArviZ. There is a loss of information in that translation, since the sampler returns pretty much all there is to know about the process, but we can get that information back (say, for debugging) using iterative sampling.

- `lax.scan` loop;
- `io_numpyro.py`;
- `InferenceData` instance with chain positions. Test plotting.
- `log_likelihood` as a `@property` (it requires extra computation).
- `HMCInfo`
- `append` method to add a single sample to the trace. Should be fast.