chrism0dwk opened 5 years ago
FWIW, a way to do it today is to write a new TransitionKernel subclass, something like this:
class Gibbs(tfp.mcmc.TransitionKernel):
    def __init__(self, target_log_prob_fn, make_kernel_fns):
        self._target_log_prob_fn = target_log_prob_fn
        self._make_kernel_fns = make_kernel_fns

    def is_calibrated(self):
        return True

    def one_step(self, state, _):
        for i, make_kernel_fn in enumerate(self._make_kernel_fns):
            def _target_log_prob_fn_part(state_part):
                state[i] = state_part
                return self._target_log_prob_fn(*state)
            kernel = make_kernel_fn(_target_log_prob_fn_part)
            state[i], _ = kernel.one_step(
                state[i], kernel.bootstrap_results(state[i]))
        return state, ()

    def bootstrap_results(self, state):
        return ()
Then you just pass it a list of kernel-creation functions (each creating a transition kernel from a target_log_prob_fn) corresponding to the elements of the state list.
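To see why the per-part closure hands each inner kernel a valid (unnormalized) conditional target, here is a minimal plain-Python sketch of the same trick; the toy density and all names are hypothetical, and no TFP is required:

```python
# Plain-Python sketch of the per-part closure pattern (hypothetical names).

def joint_log_prob(x, y):
    # Unnormalized log-density of two independent standard normals.
    return -0.5 * (x * x + y * y)

state = [1.0, 2.0]
i = 0  # update the first state part, holding the second fixed

def target_log_prob_part(state_part):
    # Write the proposed part into the shared state list, then evaluate
    # the joint: this equals the conditional target up to a constant.
    state[i] = state_part
    return joint_log_prob(*state)

assert target_log_prob_part(3.0) == joint_log_prob(3.0, 2.0)
```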
A few caveats about this snippet:
Many thanks for the solution! I've been trying to extend it to relax the requirement to recompute the current (possibly expensive) log target density at each iteration (using the kernel.bootstrap_results call). Really, we almost need the log target density to be returned with the state tensor rather than bundled in with the heterogeneous "*Results" namedtuples that you're discarding above.
def advance_target_log_prob(step_results, i):
    next_step = (i + 1) % len(step_results)
    step_results[next_step].accepted_results._replace(
        target_log_prob=step_results[i].accepted_results.target_log_prob)
    return step_results
...
def one_step(self, state, step_results):
    for i, make_kernel_fn in enumerate(self._make_kernel_fns):
        def _target_log_prob_fn_part(state_part):
            state[i] = state_part
            return self._target_log_prob_fn(*state)
        kernel = make_kernel_fn(_target_log_prob_fn_part)
        state[i], step_results[i] = kernel.one_step(state[i], step_results[i])
        step_results = advance_target_log_prob(step_results, i)
    return state, step_results
...
However, this doesn't work, because the reference to the target_log_prob tf.Tensor is lost in the namedtuple._replace call in advance_target_log_prob. Any suggestions from the TF gurus out there?
Chris
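For anyone hitting the same wall: namedtuple._replace never mutates in place; it returns a new tuple, so a bare _replace(...) call whose result is not assigned (as in advance_target_log_prob above) is silently a no-op. A tiny stand-alone demonstration, using a hypothetical results tuple:

```python
from collections import namedtuple

# Hypothetical stand-in for a kernel-results namedtuple.
Results = namedtuple("Results", ["target_log_prob"])

r = Results(target_log_prob=1.0)
r._replace(target_log_prob=2.0)      # return value discarded: no effect
assert r.target_log_prob == 1.0      # r is unchanged

r = r._replace(target_log_prob=2.0)  # rebinding the name is required
assert r.target_log_prob == 2.0
```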
[edit]
A solution that works, advancing the target_log_prob for a RandomWalkMetropolis scheme without the unnecessary recalculation of the target density. The copying of the results namedtuples is inelegant, but I guess it only runs once:
def advance_target_log_prob(next_results, previous_results):
    if isinstance(next_results.accepted_results,
                  previous_results.accepted_results.__class__):
        next_accepted = next_results.accepted_results._replace(
            target_log_prob=previous_results.accepted_results.target_log_prob)
        return next_results._replace(accepted_results=next_accepted)
    return None
...
def one_step(self, state, step_results):
    prev_step = np.roll(np.arange(len(step_results)), 1)
    for i, make_kernel_fn in enumerate(self._make_kernel_fns):
        def _target_log_prob_fn_part(state_part):
            state[i] = state_part
            return self._target_log_prob_fn(*state)
        kernel = make_kernel_fn(_target_log_prob_fn_part)
        results = (advance_target_log_prob(step_results[i],
                                           step_results[prev_step[i]])
                   or kernel.bootstrap_results(state[i]))
        state[i], step_results[i] = kernel.one_step(state[i], results)
    return state, step_results
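A side note on the indexing: np.roll just builds a previous-step lookup table, so step i can pull the log-density computed by step i-1, with step 0 wrapping around to the last step. For example:

```python
import numpy as np

# For 4 state parts: part i reuses the results produced by part i - 1,
# and part 0 wraps around to the last part.
prev_step = np.roll(np.arange(4), 1)
assert prev_step.tolist() == [3, 0, 1, 2]
```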
...
If you want to make that kind of optimization, you need to be careful that the kernel in question isn't storing other aspects of the target_log_prob_fn's evaluation: e.g. the HMC kernel also stores the gradients, so if you only replace the target_log_prob in HMC's results, you'll be breaking its invariants. I don't have a good solution for this conundrum; it's a limitation of the API.
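To make that caveat concrete, here is a hypothetical results tuple that caches both the log-density and its gradient, as HMC-style kernels do; patching only target_log_prob leaves the cached gradient describing a different point:

```python
from collections import namedtuple

# Hypothetical HMC-style results caching both the value and the gradient
# of the target at the accepted state. Toy target: f(x) = -x**2 / 2.
HmcResults = namedtuple("HmcResults", ["target_log_prob", "grads"])

old = HmcResults(target_log_prob=-2.0, grads=[-2.0])  # f and f' at x = 2
patched = old._replace(target_log_prob=-8.0)          # value now for x = 4...
assert patched.grads == [-2.0]  # ...but the gradient is still the one at x = 2
```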
Bumping this issue with a Request for Comments on a proposed way of Gibbs sampling using the tfp.mcmc framework. Gibbs sampling is important right now, particularly for stochastic COVID-19 models. If this seems interesting, we'd be happy to contribute to the project.
https://colab.research.google.com/drive/1pHJ4dExxiEg6rTSgV_XErWP4_p7gBpg3?usp=sharing
I think it would be really great to have some sort of Gibbs sampling in TFP. So thanks a lot, @chrism0dwk, for this! I had a look at your suggestion—looks quite nice to me! I have the following comments:
ReplicaExchangeMC can be used with this: in #1042 I'm moving / introducing helper methods solving exactly that issue to tfp.mcmc.internal.utils. They allow getting or updating fields in any onion of kernel results. Also, what exactly do you mean by "recent innovations in tfp.mcmc.experimental"? I just hope there is no redundancy with the code in the PR / in tfp.mcmc.replica_exchange_mc.
A long time ago I wrote a Python package for doing Gibbs sampling (this is not an ad).
My strategy there is as follows. I indeed have states essentially being a dictionary, and there's a joint distribution class with a conditional_factory() method. joint.conditional_factory(sub_state=5) would give you a new (unnormalized) distribution object with the sub_state part set to the value 5. One (systematic-scan) Gibbs iteration would then be as easy as (self is a GibbsSampler object):
for var in self._pdf.variables:
    # update the PDF objects of all subsamplers with the current state
    self._update_conditional_pdf_params()
    # draw new sample from one of the conditional PDFs
    new = self.subsamplers[var].sample()
    # update full state held by Gibbs sampler object
    self._update_state(**{var: new})
where self.subsamplers would hold the inner transition kernels with a pdf attribute given by the above-mentioned conditional distribution object.
I just thought I'd quickly describe this here as an alternative way to set things up. This worked well for me, and there's no need for the slightly awkward (no offense :wink:) replacing of the target log probability, as you do in your implementation.
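For readers who haven't seen the package, the conditional_factory idea can be sketched in a few lines of plain Python; this is a hypothetical toy, not the actual package API:

```python
class Joint:
    """Toy joint (unnormalized) log-density over named variables."""

    def __init__(self, log_prob_fn, variables):
        self.log_prob_fn = log_prob_fn
        self.variables = variables

    def conditional_factory(self, **fixed):
        """Return the unnormalized conditional over the remaining variables."""
        free = [v for v in self.variables if v not in fixed]

        def cond_log_prob(**kwargs):
            # Merge the fixed values with the free ones supplied at call time.
            return self.log_prob_fn(**{**fixed, **kwargs})

        return Joint(cond_log_prob, free)


joint = Joint(lambda x, y: -0.5 * (x - y) ** 2, ["x", "y"])
cond = joint.conditional_factory(y=1.0)  # fix y, leaving a density in x only
assert cond.variables == ["x"]
assert cond.log_prob_fn(x=1.0) == 0.0
```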
@simeoncarstens Thanks for your comments! Good to know there are helper methods going into tfp.mcmc.internal.utils -- I'll take a look as and when they're there. The helper function in tfp.mcmc.experimental I was referring to is here -- basically the recursive strategy I guess we've all taken.
The reason for the messing about with the target log prob function was really to try to keep close to the overall pattern -- very happy to see if there's another way!
Actually, another drawback of my approach is that there is no way to forward the whole chain state to an inner_kernel if a global-state-dependent update is required, e.g. like your pure Gibbs sampling approach above, or if a clever (and probably model-specific) MH kernel needed a global view of the chain to perform sensible proposals for its specific state_elem. Not sure (yet) how that could be achieved...
Thanks for putting the notebook together!
At first glance I'm most worried by the line that sets the inner kernel's target log prob fn. I think instead this may need to happen in the constructor (like TransformedTransitionKernel), or be given as a kernel-builder fn, which iirc is how replica exchange works. Interestingly, you could pass the full state into such a kernel-builder fn; maybe that resolves the question about cross-state-element dependencies. I'd be quite excited to get something into TFP to better support Gibbs.
Bump...just to say I've not detached myself from this -- just busy!
Sometimes when sampling from complex models, it would be handy to provide a list of kernels to tfp.mcmc.sample_chain instead of just one. As a simple example, suppose I have a joint probability distribution with density function $\pi(\theta, \phi, \psi)$, where $-\infty < \theta, \phi < \infty$ and $\psi \in \{-1, 0, 1\}$. Here, I might need two kernels: an HMC for $\theta, \phi$ and a Metropolis-Hastings for $\psi$.
Could we have a feature for composing such updates, as currently it is not obvious how to achieve this?
Thanks,
Chris
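As a sketch of what such a composed update could look like (outside TFP, in plain Python, all names hypothetical, and collapsing $\theta, \phi$ into a single continuous coordinate for brevity): one random-walk MH block for the continuous coordinate and one independence-MH block for the discrete coordinate, both targeting the same joint log-density:

```python
import math
import random

def log_density(theta, psi):
    # Toy unnormalized joint: Gaussian in theta, slight preference for psi == 0.
    return -0.5 * theta * theta + (0.5 if psi == 0 else 0.0)

def one_step(theta, psi, rng):
    # Block 1: random-walk MH on the continuous coordinate, psi held fixed.
    prop = theta + rng.gauss(0.0, 1.0)
    accept = math.exp(min(0.0, log_density(prop, psi) - log_density(theta, psi)))
    if rng.random() < accept:
        theta = prop
    # Block 2: independence MH on the discrete coordinate, theta held fixed.
    prop = rng.choice([-1, 0, 1])
    accept = math.exp(min(0.0, log_density(theta, prop) - log_density(theta, psi)))
    if rng.random() < accept:
        psi = prop
    return theta, psi

rng = random.Random(0)
theta, psi = 0.0, 1
for _ in range(1000):
    theta, psi = one_step(theta, psi, rng)
assert psi in (-1, 0, 1)
```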