aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License

Add a utility function to sample using `scan` #80

Status: open. Opened by rlouf 2 years ago.


Using the sampling steps built by AeMCMC in a `scan` loop is not straightforward:

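from typing import List

import aesara
import aesara.tensor as at
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.tensor.var import TensorVariable

import aemcmc

# `Y_rv` (the observed random variable), `y_tt` (its observed value) and
# `srng` (the random stream) come from the model definition, not shown here.
sample_steps, sample_updates, initial_values = aemcmc.construct_sampler(
    {Y_rv: y_tt}, srng
)

# The random variables to sample; here, for instance, every variable
# AeMCMC built a sampling step for.
to_sample_rvs: List[TensorVariable] = list(sample_steps.keys())
inputs = [initial_values[rv] for rv in to_sample_rvs]
outputs = [sample_steps[rv] for rv in to_sample_rvs]

def step_fn(*values):
    # Rebuild the sampling steps with the initial-value variables replaced
    # by the values carried by `scan`; the RNG updates are cloned along.
    vv_to_values = {inputs[i]: val for i, val in enumerate(values)}

    _, new_values, [_, new_updates, _, _] = rebuild_collect_shared(
        outputs, inputs=inputs, replace=vv_to_values, updates=sample_updates
    )

    return new_values, new_updates

n_samples = at.iscalar("n_samples")
# `samples` holds one entry per iteration for each variable in `to_sample_rvs`.
samples, updates = aesara.scan(step_fn, outputs_info=inputs, n_steps=n_samples)

sample_fn = aesara.function(inputs + [n_samples], samples, updates=updates)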

It is, however, easy to generalize. We should implement a utility function, e.g. `aemcmc.sampling_loop`, which, given the outputs of `construct_sampler` and a number of iterations `n_samples`, returns a graph that generates `n_samples` samples.
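To pin down what such a utility should compute, here is a minimal plain-Python model of the loop. It is not Aesara code and not AeMCMC's API: the signature of `sampling_loop` is hypothetical, the string keys stand in for the random variables, each sampling step is assumed to be a Python function of the previous state, and the shared `random.Random` plays the role of `sample_updates` by carrying the RNG state from one iteration to the next. Like `aesara.scan` with `outputs_info`, the returned trace contains the `n_samples` new states and excludes the initial values.

```python
import random
from typing import Callable, Dict, List, Mapping

# Hypothetical stand-in types: a step maps the full previous state
# (variable name -> value) and a random generator to one variable's new value.
State = Dict[str, float]
Step = Callable[[Mapping[str, float], random.Random], float]


def sampling_loop(
    sample_steps: Mapping[str, Step],
    initial_values: Mapping[str, float],
    n_samples: int,
    rng: random.Random,
) -> Dict[str, List[float]]:
    """Plain-Python model of a hypothetical `aemcmc.sampling_loop`.

    Mirrors `aesara.scan(step_fn, outputs_info=inputs, n_steps=n_samples)`:
    the state is carried between iterations and the trace stores the
    `n_samples` states produced by the loop, not the initial values.
    """
    to_sample = list(sample_steps)
    state: State = {name: initial_values[name] for name in to_sample}
    trace: Dict[str, List[float]] = {name: [] for name in to_sample}
    for _ in range(n_samples):
        # Every step reads the previous state (the old `state` dict stays
        # bound until the comprehension finishes), like the rebuilt outputs
        # whose inputs are replaced by the carried values.
        state = {name: sample_steps[name](state, rng) for name in to_sample}
        for name in to_sample:
            trace[name].append(state[name])
    return trace


# Usage: a deterministic counter and a Gaussian random walk.
steps = {
    "count": lambda s, r: s["count"] + 1,
    "walk": lambda s, r: s["walk"] + r.gauss(0.0, 1.0),
}
trace = sampling_loop(steps, {"count": 0.0, "walk": 0.0}, n_samples=3, rng=random.Random(0))
print(trace["count"])  # [1.0, 2.0, 3.0]
```

In an actual implementation the Python loop would presumably be replaced by the `scan` construction from the snippet in the issue, so that the whole chain compiles into a single Aesara function.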