disadone closed this issue 3 weeks ago.
I guess we can add a flag to control this behavior. Based on the flag, we can switch the order of the operations in HMCGibbs.sample.
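A minimal sketch of what such an order flag could look like, using plain Python callables in place of the real kernels. The flag name gibbs_first and the helpers two_block_step, toy_gibbs and toy_hmc are hypothetical and are not part of NumPyro's HMCGibbs API. The default (gibbs_first=True) mirrors what HMCGibbs.sample does today: a Gibbs update followed by an HMC step.

```python
from typing import Callable, Dict, Tuple

Sites = Dict[str, float]


def two_block_step(
    gibbs_sites: Sites,
    hmc_sites: Sites,
    gibbs_update: Callable[[Sites, Sites], Sites],
    hmc_update: Callable[[Sites, Sites], Sites],
    gibbs_first: bool = True,  # hypothetical flag controlling the update order
) -> Tuple[Sites, Sites]:
    """One sweep of a two-block sampler.

    Each update receives (gibbs_sites, hmc_sites) and returns new values for
    its own block, conditioned on the current values of the other block.
    """
    if gibbs_first:
        # Gibbs block conditions on the HMC values from the previous sweep.
        gibbs_sites = gibbs_update(gibbs_sites, hmc_sites)
        hmc_sites = hmc_update(gibbs_sites, hmc_sites)
    else:
        # HMC block moves first, so the Gibbs block sees the fresh HMC values.
        hmc_sites = hmc_update(gibbs_sites, hmc_sites)
        gibbs_sites = gibbs_update(gibbs_sites, hmc_sites)
    return gibbs_sites, hmc_sites


# Deterministic stand-ins so the effect of the order is easy to see.
def toy_gibbs(g: Sites, h: Sites) -> Sites:
    return {"x": h["y"] + 1.0}


def toy_hmc(g: Sites, h: Sites) -> Sites:
    return {"y": g["x"] * 2.0}


print(two_block_step({"x": 0.0}, {"y": 5.0}, toy_gibbs, toy_hmc, gibbs_first=True))
# ({'x': 6.0}, {'y': 12.0})
print(two_block_step({"x": 0.0}, {"y": 5.0}, toy_gibbs, toy_hmc, gibbs_first=False))
# ({'x': 1.0}, {'y': 0.0})
```

Passing both blocks to each update keeps the conditioning explicit, so switching the order only changes which block sees the other's freshly updated values.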
Do you think it would be easy? I'd like to try it first by modifying HMCGibbs.
Currently, in HMCGibbs.sample, we do the Gibbs update first (your ref link above) and run HMC.sample after that. It seems that this is the behavior that you want.
Could you clarify your comments here?
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):  # NEED run first
    y = hmc_sites['y']  # NEED: initialized first not sample from model
I guess you don't want to use hmc_sites['y'] from the previous MCMC step? If so, you can do y = something_else.
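A sketch of a gibbs_fn with the (rng_key, gibbs_sites, hmc_sites) signature used above that ignores hmc_sites['y'] and conditions on a fixed value instead. The constant Y_FIXED and the toy conditional (x with a Normal(0, 1) prior and y ~ Normal(x, 1), which gives x | y ~ Normal(y / 2, sqrt(1/2))) are assumptions for illustration only, not the model from this thread.

```python
import jax.numpy as jnp
from jax import random

# Hypothetical value used instead of the current HMC draw of "y".
Y_FIXED = 1.5


def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # Deliberately ignore hmc_sites["y"] (the value from the previous MCMC step)
    # and condition on a user-chosen constant instead.
    y = Y_FIXED
    # Toy conditional: x ~ Normal(0, 1), y ~ Normal(x, 1)  =>  x | y ~ Normal(y / 2, sqrt(1/2)).
    loc = y / 2.0
    scale = jnp.sqrt(0.5)
    x_new = loc + scale * random.normal(rng_key)
    return {"x": x_new}


key = random.PRNGKey(0)
out_a = gibbs_fn(key, {"x": 0.0}, {"y": jnp.array(-3.0)})
out_b = gibbs_fn(key, {"x": 0.0}, {"y": jnp.array(10.0)})
# Same key, different hmc_sites["y"]: the draws match because "y" is ignored.
print(out_a["x"], out_b["x"])
```

In a real run, a function like this would be passed as HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["x"]), with "y" sampled by NUTS; the model itself is not shown here.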
Yes, I do not want to use hmc_sites['y']. I found that the value could be overridden with the init_params value in MCMC if I switch the order of the HMC and Gibbs updates as shown here:
# Imports this method relies on (as in numpyro/infer/hmc_gibbs.py)
from functools import partial
from jax import jacfwd, random, value_and_grad
from numpyro.infer.hmc_gibbs import HMCGibbsState
def sample(self, state, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    rng_key, rng_gibbs = random.split(state.rng_key)
    def potential_fn(z_gibbs, z_hmc):
        return self.inner_kernel._potential_fn_gen(
            *model_args, _gibbs_sites=z_gibbs, **model_kwargs
        )(z_hmc)
    z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
    z_hmc = {k: v for k, v in state.z.items() if k in state.hmc_state.z}
    model_kwargs_ = model_kwargs.copy()
    model_kwargs_["_gibbs_sites"] = z_gibbs
    # NOTE: z_hmc has not been passed through postprocess_fn at this point
    z_gibbs = self._gibbs_fn(
        rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
    )  # switch the order of z_gibbs and z_hmc
    z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)
    if self.inner_kernel._forward_mode_differentiation:
        pe = potential_fn(z_gibbs, state.hmc_state.z)
        z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z)
    else:
        pe, z_grad = value_and_grad(partial(potential_fn, z_gibbs))(
            state.hmc_state.z
        )
    hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)
    model_kwargs_["_gibbs_sites"] = z_gibbs
    hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)
    z = {**z_gibbs, **hmc_state.z}
    return HMCGibbsState(z, hmc_state, rng_key)
I just wonder whether there are any unexpected side effects if I change the sample function like this.
What did you switch? I guess I misunderstood your question. We do the Gibbs update first and the HMC update later, which seems like what you want.
Sorry for the confusion. I mean the order of these statements in the modified file:
z_gibbs = self._gibbs_fn(
    rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
)  # switch the order of z_gibbs and z_hmc
z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)
In the original file, without modification, the order is:
z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)
z_gibbs = self._gibbs_fn(
    rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
)  # switch the order of z_gibbs and z_hmc
In the modified file, z_hmc is passed to self._gibbs_fn directly, without first going through the sampling part (postprocess_fn).
I wrote a print at the end of my self-defined model and found that self.inner_kernel.postprocess_fn can trigger the model and change the z_hmc value, even though postprocess_fn seems to be meant for postprocessing, not for triggering sampling...
The postprocess_fn is necessary to make sure that the HMC samples are in the correct domain for gibbs_fn to condition on. In most cases, it transforms unconstrained samples into constrained samples without triggering the model. But if your model has stochastic support, it is necessary to run the model to perform the transform correctly.
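To see why the domain matters, here is a toy sketch in plain Python (not NumPyro's actual transform machinery) of a positive HMC site sigma. It assumes HMC works on the unconstrained value u = log(sigma) and that a postprocess step maps it back with exp; toy_postprocess and toy_gibbs_fn are hypothetical helpers. If the Gibbs update runs before postprocessing, it receives u rather than sigma.

```python
import math


def toy_postprocess(hmc_sites_unconstrained):
    # Toy stand-in for a constraining transform: a positive site stored as log(sigma).
    return {"sigma": math.exp(hmc_sites_unconstrained["sigma"])}


def toy_gibbs_fn(gibbs_sites, hmc_sites):
    # Toy Gibbs update that uses sigma as a scale, so it requires sigma > 0.
    sigma = hmc_sites["sigma"]
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    return {"precision": 1.0 / sigma**2}


# HMC stores the unconstrained value: log(0.5) is negative.
z_hmc_unconstrained = {"sigma": math.log(0.5)}

# Original order: postprocess first, so the Gibbs update sees sigma = 0.5.
print(toy_gibbs_fn({}, toy_postprocess(z_hmc_unconstrained)))

# Switched order: the Gibbs update sees the raw log-scale value and rejects it.
try:
    toy_gibbs_fn({}, z_hmc_unconstrained)
except ValueError as err:
    print("error:", err)
```

A real gibbs_fn would usually not raise; it would silently condition on the wrong value, which is harder to notice than the error shown here.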
Thank you, I understand the point!
I am currently working on something that uses HMCGibbs. I found that it always samples the model part several times for NUTS or HMC and then runs into the gibbs_fn. However, my program needs to apply gibbs_fn first, skipping all the distribution definitions related to gibbs_site, with the hmc_site variables initialized to defined values. Is this possible? It seems that HMCGibbs does not support such an order.
https://github.com/pyro-ppl/numpyro/blob/401e364c323aed35ca3235b5c92971b7449dab85/numpyro/infer/hmc_gibbs.py#L166-L170
A minimal example could look like this: