Using the sampling steps built by AeMCMC in a `scan` loop is not straightforward:
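```python
from typing import List

import aesara
import aesara.tensor as at
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.tensor.var import TensorVariable

import aemcmc

# `Y_rv` (observed random variable), `y_tt` (its observed value) and `srng`
# (a `RandomStream`) are assumed to be defined by the model
sample_steps, sample_updates, initial_values = aemcmc.construct_sampler(
    {Y_rv: y_tt}, srng
)

# The random variables to sample, e.g. all those AeMCMC built a step for
to_sample_rvs: List[TensorVariable] = list(sample_steps.keys())
inputs = [initial_values[rv] for rv in to_sample_rvs]
outputs = [sample_steps[rv] for rv in to_sample_rvs]


def step_fn(*values):
    # Replace the value variables with scan's inner state and carry the
    # RNG updates into the loop body
    vv_to_values = {inputs[i]: val for i, val in enumerate(values)}
    _, new_values, [_, new_updates, _, _] = rebuild_collect_shared(
        outputs, inputs=inputs, replace=vv_to_values, updates=sample_updates
    )
    return new_values, new_updates


n_samples = at.iscalar("n_samples")
samples, updates = aesara.scan(step_fn, outputs_info=inputs, n_steps=n_samples)
sample_fn = aesara.function(inputs + [n_samples], samples, updates=updates)
```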
It is, however, easily generalizable. We should implement a utility function, e.g. `aemcmc.sampling_loop`, which, given the outputs of `construct_sampler` and a number of iterations `n_samples`, returns a graph that generates `n_samples` samples.