vanAmsterdam closed this issue 4 years ago.
Thanks, @vanAmsterdam! I guess the easiest workaround is to do `smps = mcmc.get_samples(group_by_chain=False)`, compute the log likelihoods, then reshape the output to `num_chains x num_samples x ...`. The job could be easier if we add an argument `batch_ndim` or something like that, so users just need to set `batch_ndim=1` to get the current behavior, `batch_ndim=2` to get per-chain likelihoods, and `batch_ndim=0` to get the likelihood of a single sample.
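The flatten, evaluate, reshape workaround and the proposed `batch_ndim` semantics can be sketched in plain NumPy. This is a minimal sketch, not NumPyro's implementation: `normal_loglik` and `batched_loglik` are hypothetical helpers, and the model is assumed to be a Normal likelihood with one scalar location `mu` and unit scale.

```python
import numpy as np


def normal_loglik(mu, y, sigma=1.0):
    """Pointwise log N(y | mu, sigma^2) for a single draw of mu (hypothetical toy model)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def batched_loglik(mu, y, batch_ndims):
    """Hypothetical helper: evaluate normal_loglik over the leading `batch_ndims` dims of mu.

    batch_ndims=0 -> mu is one draw;                   result shape y.shape
    batch_ndims=1 -> mu is (num_draws,);                result (num_draws,) + y.shape
    batch_ndims=2 -> mu is (num_chains, num_samples);   result (num_chains, num_samples) + y.shape
    """
    mu = np.asarray(mu)
    batch_shape = mu.shape[:batch_ndims]
    # Collapse the batch dims into one axis (C order keeps chain-major ordering).
    flat = mu.reshape((-1,) + mu.shape[batch_ndims:])
    per_draw = np.stack([normal_loglik(m, y) for m in flat])
    # Restore the original batch dims in front of the per-observation axis.
    return per_draw.reshape(batch_shape + y.shape)


num_chains, num_samples = 2, 5
y = np.array([0.5, -0.2, 1.3])
mu_chained = np.linspace(-1.0, 1.0, num_chains * num_samples).reshape(num_chains, num_samples)

by_chain = batched_loglik(mu_chained, y, batch_ndims=2)
print(by_chain.shape)  # (2, 5, 3)
```

Evaluating the flattened draws with `batch_ndims=1` and reshaping to `(num_chains, num_samples, ...)` gives the same array as `batch_ndims=2`, assuming the flat samples are stacked chain by chain.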
@neerajprad It is a bit inconvenient for users to deal with the reshaping and to carry around both non-chained and chained samples, because some utilities work with chained samples (e.g. diagnostics) while others work with the non-chained version (likelihood, predictive). WDYT?
@vanAmsterdam You can now compute `log_likelihood` per chain by specifying `batch_ndims=2`.
When checking inference, diagnosing log likelihoods by chain could provide some insight, e.g. when the posterior has different modes. Currently, passing samples that are `group_by_chain`-ed to `log_likelihood` will lead to broadcasting errors. Maybe we can add something like this:

The only thing I have still been struggling with is keeping the possibility to pass `*args` to the model call...