pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

How can I run the Gibbs step before HMC/NUTS? #1812

Closed: disadone closed this issue 3 weeks ago

disadone commented 4 weeks ago

I am currently working on a project that uses HMCGibbs. I found that it always runs the model several times for the NUTS/HMC part and only then reaches gibbs_fn. However, my program needs to apply gibbs_fn first, skipping all of the distribution definitions related to the Gibbs sites, with the HMC-site variables taken from their initialized values.

Is this possible? It seems that HMCGibbs does not support this ordering.

https://github.com/pyro-ppl/numpyro/blob/401e364c323aed35ca3235b5c92971b7449dab85/numpyro/infer/hmc_gibbs.py#L166-L170

A minimal example could look like this:

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

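def gibbs_fn(rng_key, gibbs_sites, hmc_sites):  # NEED: this should run first
    y = hmc_sites['y']  # NEED: the initialized value, not a sample from the model
    # Exact conditional p(x | y, obs) for this model:
    # variance 1 / (1/4 + 1) = 0.8, mean 0.8 * (1 - y)
    new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))

fehiepsi commented 4 weeks ago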

I guess we can add a flag to control this behavior. Based on the flag, we can switch the order of operations in HMCGibbs.sample.

disadone commented 3 weeks ago

Do you think that would be easy? I'd like to try it first by modifying HMCGibbs myself.

fehiepsi commented 3 weeks ago

Currently, in HMCGibbs.sample, we do the Gibbs update first (the lines you linked above) and run HMC.sample after that. It seems that this is the behavior that you want.

Could you clarify your comments here?

def gibbs_fn(rng_key, gibbs_sites, hmc_sites): # NEED run first
    y = hmc_sites['y'] # NEED: initialized first not sample from model

I guess you don't want to use hmc_sites['y'] from the previous MCMC step? If so, you can do y = something_else.

disadone commented 3 weeks ago

Yes, I do not want to use hmc_sites['y']. I found that the value could be overridden with the init_params value passed to MCMC if I switch the order of the HMC and Gibbs steps, as shown here:

    def sample(self, state, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs
        rng_key, rng_gibbs = random.split(state.rng_key)

        def potential_fn(z_gibbs, z_hmc):
            return self.inner_kernel._potential_fn_gen(
                *model_args, _gibbs_sites=z_gibbs, **model_kwargs
            )(z_hmc)

        z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
        z_hmc = {k: v for k, v in state.z.items() if k in state.hmc_state.z}
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_gibbs

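        # Modified: gibbs_fn now runs BEFORE postprocess_fn, so it receives
        # z_hmc taken directly from state.z (still in unconstrained space).
        z_gibbs = self._gibbs_fn(
            rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
        )

        # In this modified version, z_hmc is not used after this line.
        z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)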

        if self.inner_kernel._forward_mode_differentiation:
            pe = potential_fn(z_gibbs, state.hmc_state.z)
            z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z)
        else:
            pe, z_grad = value_and_grad(partial(potential_fn, z_gibbs))(
                state.hmc_state.z
            )
        hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)

        model_kwargs_["_gibbs_sites"] = z_gibbs
        hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)

        z = {**z_gibbs, **hmc_state.z}

        return HMCGibbsState(z, hmc_state, rng_key)

I just wonder whether there are any unexpected side effects if I change the sample function like this.

fehiepsi commented 3 weeks ago

What did you switch? I guess I misunderstood your question. We do the Gibbs update first and the HMC update later, which seems like what you want.

disadone commented 3 weeks ago

Sorry for the confusion. I swapped the order of these statements. In my modified file:

z_gibbs = self._gibbs_fn(
    rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
)

z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)

In the original, unmodified file:

z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)

z_gibbs = self._gibbs_fn(
    rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
)

In the modified file, z_hmc is not run through the model before it is passed to self._gibbs_fn. I added a print at the end of my model and found that self.inner_kernel.postprocess_fn can trigger the model and change the value of z_hmc, even though postprocess_fn seems to be meant for postprocessing rather than for triggering sampling...

fehiepsi commented 3 weeks ago

The postprocess_fn is necessary to make sure that hmc samples are in the correct domain for the gibbs_fn to condition on. In most cases, it will transform unconstrained samples into constrained samples without triggering the model. But if your model has stochastic support, it is necessary to run the model to perform the transform correctly.

disadone commented 3 weeks ago

Thank you, I understand the point!