tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Supply list of kernels to tfp.mcmc.sample_chain #495

Open chrism0dwk opened 5 years ago

chrism0dwk commented 5 years ago

Sometimes, when sampling from complex models, it would be handy to provide a list of kernels to tfp.mcmc.sample_chain instead of just one. As a simple example, suppose I have a joint probability distribution with density function $\pi(\theta, \phi, \psi)$, where $-\infty < \theta, \phi < \infty$ and $\psi \in \{-1, 0, 1\}$. Here, I might need two kernels: an HMC kernel for $\theta, \phi$ and a Metropolis-Hastings kernel for $\psi$.

Could we have a feature for composing such updates? Currently it is not obvious how to achieve this.
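For concreteness, a minimal sketch of the sort of joint log density I have in mind (the particular distributions here are just placeholders, not my actual model):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Placeholder joint log density pi(theta, phi, psi): theta and phi are
# unconstrained reals, psi is discrete with support {-1, 0, 1}.
def joint_log_prob(theta, phi, psi):
    lp = tfd.Normal(0., 1.).log_prob(theta)
    lp += tfd.Normal(theta, 1.).log_prob(phi)
    # Index the three discrete values {-1, 0, 1} as {0, 1, 2}.
    lp += tfd.Categorical(logits=tf.zeros(3)).log_prob(psi + 1)
    return lp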

Thanks,

Chris

SiegeLordEx commented 5 years ago

FWIW, a way to do it today is to write a new TransitionKernel subclass, something like this:

import tensorflow_probability as tfp

class Gibbs(tfp.mcmc.TransitionKernel):
    def __init__(self, target_log_prob_fn, make_kernel_fns):
        self._target_log_prob_fn = target_log_prob_fn
        self._make_kernel_fns = make_kernel_fns

    def is_calibrated(self):
        return True

    def one_step(self, state, _):
        # Sweep over the state parts, updating each one in turn with its own
        # inner kernel, conditional on the current values of the other parts.
        for i, make_kernel_fn in enumerate(self._make_kernel_fns):
            def _target_log_prob_fn_part(state_part):
                # Partial application of the joint log prob: only part i varies.
                state[i] = state_part
                return self._target_log_prob_fn(*state)
            kernel = make_kernel_fn(_target_log_prob_fn_part)
            state[i], _ = kernel.one_step(state[i], kernel.bootstrap_results(state[i]))
        return state, ()

    def bootstrap_results(self, state):
        return ()

Then you just pass it a list of kernel creation functions (that create a transition kernel from a target_log_prob_fn) that correspond to the elements of the state list.
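For instance (a rough sketch only; the toy target, step sizes, and kernel choices below are arbitrary placeholders):

import tensorflow as tf
import tensorflow_probability as tfp

# Toy two-part target; each state part gets its own kernel-builder.
def target_log_prob(x, y):
    return -tf.reduce_sum(x ** 2) - tf.reduce_sum(y ** 2)

gibbs_kernel = Gibbs(
    target_log_prob_fn=target_log_prob,
    make_kernel_fns=[
        lambda tlp: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp, step_size=0.1, num_leapfrog_steps=3),
        lambda tlp: tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=tlp),
    ])

samples = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=[tf.zeros([]), tf.zeros([])],
    kernel=gibbs_kernel,
    trace_fn=None)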

A few caveats about this snippet: the per-kernel results are discarded, and bootstrap_results (and hence the possibly expensive target log prob) is recomputed at every inner step.

chrism0dwk commented 5 years ago

Many thanks for the solution! I've been trying to extend it to relax the requirement to recompute the current (possibly expensive) log target density at each iteration (using the kernel.bootstrap_results call). Really, we almost need the log target density to be returned with the state tensor rather than bundled in with the heterogeneous "*Results" namedtuples that you're discarding above.

def advance_target_log_prob(step_results, i):
    # Copy the target_log_prob just computed by kernel i into the results of
    # the next kernel in the sweep, so it needn't be recomputed there.
    next_step = (i + 1) % len(step_results)
    step_results[next_step].accepted_results._replace(
        target_log_prob=step_results[i].accepted_results.target_log_prob)
    return step_results
...
    def one_step(self, state, step_results):
        for i, make_kernel_fn in enumerate(self._make_kernel_fns):
            def _target_log_prob_fn_part(state_part):
                state[i] = state_part
                return self._target_log_prob_fn(*state)
            kernel = make_kernel_fn(_target_log_prob_fn_part)
            state[i], step_results[i] = kernel.one_step(state[i], step_results[i])
            step_results = advance_target_log_prob(step_results, i)
        return state, step_results
...

However, this doesn't work: the namedtuple._replace call in advance_target_log_prob returns a new namedtuple rather than modifying the existing one, so the reference to the target_log_prob tf.Tensor is lost. Any suggestions from the TF gurus out there?
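To illustrate the underlying problem (plain Python, nothing TFP-specific):

from collections import namedtuple

Results = namedtuple('Results', ['target_log_prob'])
r = Results(target_log_prob=1.0)
r._replace(target_log_prob=2.0)  # returns a new namedtuple, which is discarded here
print(r.target_log_prob)         # still 1.0 -- the original is unchanged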

Chris

[edit]

A solution that works, advancing the target_log_prob for a RandomWalkMetropolis scheme without the unnecessary recalculation of the target density. The copying of the results namedtuples is inelegant, but I guess it only runs once:


import numpy as np

def advance_target_log_prob(next_results, previous_results):
    # Only carry the cached target_log_prob across if the two kernels share
    # the same results structure; otherwise return None to signal that a
    # recomputation (bootstrap_results) is needed.
    if isinstance(next_results.accepted_results, previous_results.accepted_results.__class__):
        next_accepted = next_results.accepted_results._replace(
            target_log_prob=previous_results.accepted_results.target_log_prob)
        return next_results._replace(accepted_results=next_accepted)
    return None
...
    def one_step(self, state, step_results):
        # prev_step[i] is the index of the kernel updated just before kernel i.
        prev_step = np.roll(np.arange(len(step_results)), 1)
        for i, make_kernel_fn in enumerate(self._make_kernel_fns):
            def _target_log_prob_fn_part(state_part):
                state[i] = state_part
                return self._target_log_prob_fn(*state)
            kernel = make_kernel_fn(_target_log_prob_fn_part)
            results = (advance_target_log_prob(step_results[i], step_results[prev_step[i]])
                       or kernel.bootstrap_results(state[i]))
            state[i], step_results[i] = kernel.one_step(state[i], results)
        return state, step_results
...
SiegeLordEx commented 5 years ago

If you want to make that kind of optimization, you need to be careful that the kernel in question isn't storing other aspects of the target_log_prob_fn's evaluation: e.g. an HMC kernel will also store the gradients, so if you only replace the target_log_prob in HMC's results, you'll be breaking its invariants. I don't have a good solution for this conundrum; it's a limitation of the API.
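To make that concrete, a rough sketch of inspecting what an HMC kernel caches (treat the exact field names as illustrative of the current TFP results structures rather than guaranteed):

import tensorflow as tf
import tensorflow_probability as tfp

# Sketch: an HMC kernel's results cache the gradients of the target log prob
# as well as its value, and those gradients are w.r.t. this kernel's own
# state part, so they can't simply be copied across kernels the way the
# scalar log prob value can.
hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -x ** 2, step_size=0.1, num_leapfrog_steps=3)
results = hmc.bootstrap_results(tf.constant(0.))
print(results.accepted_results.target_log_prob)
print(results.accepted_results.grads_target_log_prob)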

chrism0dwk commented 4 years ago

Bumping this issue with a Request for Comments on a proposed way of Gibbs sampling using the tfp.mcmc framework. Gibbs sampling is important right now, particularly for stochastic COVID-19 models. If this seems interesting, we'd be happy to contribute to the project.

https://colab.research.google.com/drive/1pHJ4dExxiEg6rTSgV_XErWP4_p7gBpg3?usp=sharing

simeoncarstens commented 4 years ago

I think it would be really great to have some sort of Gibbs sampling in TFP. So thanks a lot, @chrism0dwk, for this! I had a look at your suggestion—looks quite nice to me! I have the following comments:

A long time ago I wrote a Python package for doing Gibbs sampling (this is not an ad). My strategy there is as follows: states are essentially dictionaries, and there's a joint distribution class with a conditional_factory() method. joint.conditional_factory(sub_state=5) would give you a new (unnormalized) distribution object with the sub_state part fixed to the value 5. One (systematic-scan) Gibbs iteration would then be as easy as (self is a GibbsSampler object):

for var in self._pdf.variables:
    # update the PDF objects of all subsamplers with the current state
    self._update_conditional_pdf_params()
    # draw new sample from one of the conditional PDFs
    new = self.subsamplers[var].sample()
    # update full state held by Gibbs sampler object
    self._update_state(**{var: new})

where self.subsamplers would hold the inner transition kernels with a pdf attribute given by above-mentioned conditional distribution object.
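A stripped-down sketch of the conditional_factory idea (not the actual package code, just the shape of it):

class Joint:
    """Unnormalized joint density over named variables."""

    def __init__(self, log_prob_fn, variables):
        self._log_prob_fn = log_prob_fn  # takes one keyword argument per variable
        self.variables = variables

    def conditional_factory(self, **fixed):
        # Return a new unnormalized density with the given variables pinned,
        # leaving the remaining variables free.
        free = [v for v in self.variables if v not in fixed]

        def conditional_log_prob(**kwargs):
            return self._log_prob_fn(**{**fixed, **kwargs})

        return Joint(conditional_log_prob, free)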

I just thought I'd quickly describe this here as an alternative way to set this up. This worked well for me, and there's no need for the slightly awkward (no offense :wink:) replacing of the target log probability, as you do in your implementation.

chrism0dwk commented 4 years ago

@simeoncarstens Thanks for your comments! Good to know there are helper methods going into tfp.mcmc.internal.utils -- I'll take a look as and when they're there. The helper function in tfp.mcmc.experimental I was referring to is here -- basically the recursive strategy I guess we've all taken.

The reason for the messing about with the target log prob function was really to try to keep close to the overall pattern -- very happy to see if there's another way!

chrism0dwk commented 4 years ago

Actually, another drawback of my approach is that there is no way to forward the whole chain state to an inner_kernel when a global-state-dependent update is required, e.g. your pure Gibbs sampling approach above, or a clever (and probably model-specific) MH kernel that needs a global view of the chain to make sensible proposals for its specific state_elem. Not sure (yet) how that could be achieved...

brianwa84 commented 4 years ago

Thanks for putting the notebook together!

At first glance I'm most worried by the line that sets the inner kernel's target log prob fn. I think instead this may need to happen in the constructor (like TransformedTransitionKernel), or be given as a kernel-builder fn, which IIRC is how replica exchange works. Interestingly, you could pass the full state into such a kernel-builder fn; maybe that resolves the question about cross-state-element dependencies. I'd be quite excited to get something into TFP to better support Gibbs.
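Roughly what I have in mind (just a sketch, not a settled API; the names are hypothetical): the Gibbs wrapper would call each kernel-builder with the partial target log prob and the full chain state, so an inner kernel can use the other state parts when constructing its proposal:

import tensorflow_probability as tfp

# Hypothetical kernel-builder signature: the wrapper calls
# make_kernel_fn(partial_target_log_prob_fn, full_state) on each sweep,
# so a model-specific proposal can peek at the rest of the chain state.
def make_psi_kernel(target_log_prob_fn, full_state):
    theta, phi, psi = full_state
    # ...use theta and phi here to build a cleverer, model-specific proposal...
    return tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=target_log_prob_fn)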

chrism0dwk commented 4 years ago

Bump...just to say I've not detached myself from this -- just busy!