aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License

`construct_sampler` does not support transformed observables #89

Open brandonwillard opened 1 year ago

brandonwillard commented 1 year ago

It looks like we need to convert model graphs into AePPL's intermediate representation (IR) and operate on those, because `construct_sampler`/`construct_ir_fgraph` only support observed `RandomVariable`s, not general `MeasurableVariable`s.

An example:

import aesara.tensor as at
import aemcmc

srng = at.random.RandomStream(23920)

X_rv = srng.normal(0, 1, name="X")
Y_rv = 1 + X_rv  # a deterministic transform of X_rv, not a RandomVariable

obs_rvs_to_values = {Y_rv: at.scalar("y")}  # observe the transformed variable

sample_steps, updates, initial_values, nuts_parameters = aemcmc.construct_sampler(
    obs_rvs_to_values, srng
)

Running this raises a `KeyError`:

aemcmc/basic.py:40: in construct_sampler
    fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph(
aemcmc/rewriting.py:91: in construct_ir_fgraph
    obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()}
aemcmc/rewriting.py:91: in <dictcomp>
    obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()}
E   KeyError: Elemwise{add,no_inplace}.0
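The failing line is the dictionary comprehension in `construct_ir_fgraph`: `memo` has no entry for the `Elemwise{add}` output that defines `Y_rv`, so looking it up raises `KeyError`. The toy model below reproduces that lookup pattern in plain Python. The classes `Var`, `RandomVar` and `Derived` and the functions `build_memo` and `remap_observed` are hypothetical stand-ins, not AeMCMC or Aesara APIs, and the rule that only random variables receive a `memo` entry is an assumption inferred from the error and the description above, not taken from AeMCMC's source.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical graph variable; frozen so it can be used as a dict key."""
    name: str


@dataclass(frozen=True)
class RandomVar(Var):
    """Hypothetical stand-in for an Aesara RandomVariable output."""


@dataclass(frozen=True)
class Derived(Var):
    """Hypothetical stand-in for a deterministic transform such as `1 + X_rv`."""
    inputs: tuple = ()


def build_memo(variables):
    # Assumption: only random variables are cloned into the IR graph,
    # so only they get an old -> new entry in `memo`.
    return {v: Var(v.name + "_ir") for v in variables if isinstance(v, RandomVar)}


def remap_observed(obs_rvs_to_values, memo):
    # Same lookup pattern as the failing line in the traceback.
    return {memo[k]: v for k, v in obs_rvs_to_values.items()}


X_rv = RandomVar("X")
Y_rv = Derived("Y", inputs=(X_rv,))
memo = build_memo([X_rv, Y_rv])

print(remap_observed({X_rv: "x"}, memo))  # observing X_rv works

try:
    remap_observed({Y_rv: "y"}, memo)  # observing the transform does not
except KeyError as err:
    print("KeyError:", err)
```

Under this toy model, supporting transformed observables means `build_memo` would have to cover every measurable variable (here `Y_rv`, an invertible function of `X_rv`), not only `RandomVar` instances, which is roughly what operating on the AePPL IR would provide.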

rlouf commented 1 year ago

Yes. AeMCMC will also need the extra information that we'll add to `MeasurableVariable` eventually.
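For context on why `Y_rv = 1 + X_rv` is a sensible observable even though it is not a `RandomVariable`: a shift of a random variable still has a density, given by the change-of-variables formula log p_Y(y) = log p_X(y - 1), since the Jacobian of a shift is 1. The sketch below checks this numerically with the standard library only. It illustrates the kind of log-density a measurable, transformed observable has to yield; the helper names `normal_logpdf` and `shifted_logpdf` are hypothetical, and this is not AePPL's or AeMCMC's implementation.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def shifted_logpdf(y, base_logpdf, shift):
    """Log-density of Y = shift + X via change of variables.

    Inverting gives X = y - shift, and |dX/dy| = 1, so the
    log-Jacobian term is 0 and only the base density remains.
    """
    return base_logpdf(y - shift)


# Y = 1 + X with X ~ Normal(0, 1), as in the example above.
for y in (-1.0, 0.0, 1.0, 2.5):
    via_transform = shifted_logpdf(y, normal_logpdf, shift=1.0)
    direct = normal_logpdf(y, mu=1.0)  # Y ~ Normal(1, 1) directly
    print(y, via_transform, direct)
```

Both columns agree because a unit-variance normal shifted by 1 is exactly `Normal(1, 1)`; for non-linear transforms the log-Jacobian term would no longer vanish.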