aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License

Sample remaining variables with the NUTS sampler #68

Closed rlouf closed 2 years ago

rlouf commented 2 years ago

In this PR we assign the NUTS sampler to the variables that have not been assigned a sampler in construct_sampler.

More details to come

Closes #47

rlouf commented 2 years ago

To understand the kind of changes that having several possible samplers (including parametrized samplers) will require, let’s take a non-trivial example of building sampling functions for the Horseshoe prior, taken from AeMCMC’s test suite:

import aesara.tensor as at
from aemcmc.basic import construct_sampler

srng = at.random.RandomStream(0)

X = at.matrix("X")

# Horseshoe `beta_rv`
tau_rv = srng.halfcauchy(0, 1, name="tau")
lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")

a = at.scalar("a")
b = at.scalar("b")
h_rv = srng.gamma(a, b, name="h")

# Negative-binomial regression
eta = X @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.nbinom(h_rv, p, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

We observe Y_rv, and we want to sample from the posterior distribution of tau_rv, lmbda_rv, beta_rv, h_rv. AeMCMC currently provides a construct_sampler function:

sample_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})

The sample_steps dictionary maps the random variables to the sampling step that was assigned to them. We can print the graph of the sampling step assigned to lmbda_rv:

import aesara

print(f"Variables to sample: {sample_steps.keys()}\n")
# Variables to sample: dict_keys([tau, lambda, beta, h])

aesara.dprint(sample_steps[lmbda_rv])
# Elemwise{reciprocal,no_inplace} [id A] 'lambda_posterior'
#  |Elemwise{sqrt,no_inplace} [id B]
#    |exponential_rv{0, (0,), floatX, False}.1 [id C]
#      |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F5A0A0>) [id D]
#      |TensorConstant{[]} [id E]
#      |TensorConstant{11} [id F]
#      |Elemwise{add,no_inplace} [id G]
#        |exponential_rv{0, (0,), floatX, False}.1 [id H]
#        | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F589E0>) [id I]
#        | |TensorConstant{[]} [id J]
#        | |TensorConstant{11} [id K]
#        | |Elemwise{add,no_inplace} [id L]
#        |   |InplaceDimShuffle{x} [id M]
#        |   | |TensorConstant{1} [id N]
#        |   |Elemwise{reciprocal,no_inplace} [id O]
#        |     |Elemwise{pow,no_inplace} [id P]
#        |       |lambda [id Q]
#        |       |InplaceDimShuffle{x} [id R]
#        |         |TensorConstant{2} [id S]
#        |Elemwise{true_div,no_inplace} [id T]
#          |Elemwise{mul,no_inplace} [id U]
#          | |Elemwise{mul,no_inplace} [id V]
#          | | |InplaceDimShuffle{x} [id W]
#          | | | |TensorConstant{0.5} [id X]
#          | | |Elemwise{pow,no_inplace} [id Y]
#          | |   |beta [id Z]
#          | |   |InplaceDimShuffle{x} [id BA]
#          | |     |TensorConstant{2} [id BB]
#          | |InplaceDimShuffle{x} [id BC]
#          |   |Elemwise{reciprocal,no_inplace} [id BD]
#          |     |Elemwise{pow,no_inplace} [id BE]
#          |       |tau [id BF]
#          |       |TensorConstant{2} [id BG]
#          |InplaceDimShuffle{x} [id BH]
#            |TensorConstant{1.0} [id BI]

Samplers update the state of the random number generator, and the caller will need to pass these updates to the compiler later, so we return them as well. updates is a dictionary that contains the updates to the state of the random number generator that we passed via srng:

print(updates)
# {RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F59B60>): for{cpu,scan_fn}.1}

And finally we return the initial value variables of the random variables we wish to sample from:

print(initial_values)
# {tau: tau, lambda: lambda, beta: beta, h: h}

We can now easily build the graph for the sampler:

sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

outputs = [sample_steps[rv] for rv in sample_vars]
inputs = [X, a, b, y_vv] + [initial_values[rv] for rv in sample_vars]

# Don't forget the updates!
sampler = aesara.function(
    inputs,
    outputs,
    updates=updates,
)

And we can run the sampler function in a Python loop. Although this works perfectly for the Gibbs samplers, the downstream caller has no high-level information about which transformations were applied to the graph or which samplers were assigned to the variables. They would have to reverse-engineer this information from the graph they receive. Even without considering NUTS, this becomes problematic the day we return a stream of samplers: how are humans (or machines) to reason about what AeMCMC returns?
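The Python loop mentioned above could look like the sketch below. It is only an illustration: run_chain and its arguments are hypothetical names, not part of AeMCMC, and it assumes the compiled function takes the data inputs followed by the current values and returns one new value per sampled variable, in the same order.

```python
from typing import Callable, List, Sequence


def run_chain(
    sampler: Callable[..., Sequence],
    data: Sequence,
    initial_state: Sequence,
    num_samples: int,
) -> List[list]:
    """Call `sampler(*data, *state)` repeatedly, feeding each output back in."""
    state = list(initial_state)
    samples = []
    for _ in range(num_samples):
        # The sampler returns one new value per sampled variable,
        # in the same order as the state inputs it received.
        state = list(sampler(*data, *state))
        samples.append(state)
    return samples
```

With the function compiled above, data would hold the values for X, a, b and y, and initial_state one value per entry of sample_vars. Nothing in this loop tells the caller which sampler produced each value, which is the problem discussed here.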

Other issues related to information arise with NUTS:

  1. NUTS works best with unconstrained variables. We thus need to transform the graph; how do we convey this information?
  2. NUTS needs parameters to run. Downstream callers need to know that if they want to write a sampling loop.
  3. NUTS’s parameters need to go through an adaptation mechanism. How do we provide the update functions for these parameters? How do we let the caller know? (This question can be answered independently; I will leave it aside for now.)
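To make point 1 concrete, here is a generic illustration of what moving a positive variable to unconstrained space involves. It uses a standard half-Cauchy, as for tau in the example, and a log transform. This is plain Python written for this comment under those assumptions, not the transform machinery of AeMCMC or aeppl.

```python
import math


def halfcauchy_logpdf(x: float) -> float:
    """Log-density of a standard half-Cauchy, defined for x > 0."""
    return math.log(2.0 / math.pi) - math.log1p(x * x)


def unconstrained_logpdf(z: float) -> float:
    """Log-density of z = log(x), defined on the whole real line."""
    x = math.exp(z)
    # log|dx/dz| = z is the Jacobian correction of the log transform.
    return halfcauchy_logpdf(x) + z
```

A caller who receives values of z needs to know that exp was applied to recover tau, which is why the applied transformations must be part of what is returned.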

We can simply create sampler types. Let us forget for a minute everything we know about the Horseshoe prior. We pass the previous model to AeMCMC but have no idea what the output sampling steps may be:

sample_steps, updates, initial_values = aemcmc.construct_sampler(srng, {Y_rv: y_vv})

What would the interface be like? As a statistician, I would like to get some textual information about the sampling steps, for instance here:

print(sample_steps[lmbda_rv])
# "Gibbs sampler"
# - Mathematical equations that describe the sampler
# - Names of the variables we're conditioning on at this step (beta, tau)
# - It does not have any parameter

But AeMCMC can also return parametrized sampling steps. If NUTS were assigned, I would like (need) to know:

print(sample_steps[lmbda_rv])
# "NUTS sampler"
# - What are the parameters? What do we already know about them (shapes, types)?
# - What transformations were applied to the random variables?
# - What other variables are jointly sampled with NUTS?

A machine caller would also need some of this information, especially whether it will need to provide values for some parameters. We can encapsulate this information in a data structure:

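from dataclasses import dataclass
from typing import Dict, Optional

from aeppl.transforms import RVTransform
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable


@dataclass(frozen=True)
class NUTSSamplingStep:
    sampling_steps: Dict[RandomVariable, TensorVariable]
    # No default here: a field with a default cannot precede required fields
    updates: Optional[Dict]
    parameters: Dict[str, TensorVariable]
    initial_values: Dict[RandomVariable, TensorVariable]
    transforms: Optional[Dict[RandomVariable, RVTransform]]
    # Not annotated, so this is a class attribute and not a dataclass field
    _sampler_name = "NUTS"

    def __post_init__(self):
        # The dataclass is frozen, so we bypass its own __setattr__
        object.__setattr__(
            self,
            "rvs_to_sample",
            tuple(self.sampling_steps.keys()),
        )

    def __str__(self):
        return (
            f"{self._sampler_name}\n"
            f"Transforms: {self.transforms}\n"
            f"Variables sampled together: {self.rvs_to_sample}"
        )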

This data structure is created by calling a construct_nuts_sampler function:

def construct_nuts_sampler(srng, rvs_to_sample, rvs_to_values):
    # 1. Initialize the parameter values
    parameters = {
        "step_size": None,
        "inverse_mass_matrix": None,
    }
    # 2. Create initial value variables in original space
    initial_values = ...
    # 3. Look for transformations and apply them to initial vv
    transforms = ...
    # 4. Create the `rp_map`
    # 5. Build the logprob graph
    # 6. Build the NUTS sampling step
    sampling_steps = ...
    # RNG updates produced while building the sampling step
    updates = ...

    return NUTSSamplingStep(sampling_steps, updates, parameters, initial_values, transforms)

Notice that I have grouped the variables assigned to NUTS under the same data structure, a block. This is necessary because we need to know that a single value of step_size is shared by all these variables. Machine callers only need to pay extra attention when the sampler is parametrized; we can thus implement the following hierarchy of sampler types: SamplingStep and ParametrizedSamplingStep, with every sampler inheriting from one of the two. This can be extended; for instance, algorithms in the HMC family all need a step size and an inverse mass matrix parameter.

construct_sampler then returns a list of sampling steps:
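A minimal sketch of such a hierarchy follows. The field names, the HMCFamilyStep class and the needs_parameters helper are placeholders chosen for illustration; the types are left as Any so that the snippet runs without Aesara.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class SamplingStep:
    """A sampling step that needs nothing from the caller."""

    sampling_steps: Dict[Any, Any]  # random variable -> graph of its update
    updates: Dict[Any, Any] = field(default_factory=dict)

    @property
    def rvs_to_sample(self) -> Tuple[Any, ...]:
        return tuple(self.sampling_steps)


@dataclass(frozen=True)
class ParametrizedSamplingStep(SamplingStep):
    """A sampling step whose parameters must be supplied by the caller."""

    parameters: Dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class HMCFamilyStep(ParametrizedSamplingStep):
    """Hypothetical intermediate type for HMC-like samplers such as NUTS."""

    def __post_init__(self):
        missing = {"step_size", "inverse_mass_matrix"} - set(self.parameters)
        if missing:
            raise ValueError(f"missing parameters: {sorted(missing)}")


def needs_parameters(step: SamplingStep) -> bool:
    """What a machine caller would check before writing its sampling loop."""
    return isinstance(step, ParametrizedSamplingStep)
```

A single isinstance check is then enough for a machine caller to decide whether it has to provide parameter values for a block.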

sampling_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})

We can add some syntactic sugar, e.g. by having sampling_steps be a data structure where sampling steps for each random variable can be accessed by key so as to preserve the original design:

sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

sampling_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})
outputs = [sampling_steps[rv] for rv in sample_vars]  # returns e.g. NUTSSamplingStep.sampling_steps[rv]

print(sampling_steps.parameters)
# {"step_size": at.scalar(), "inverse_mass_matrix": at.vector()}

print(sampling_steps)
# Print information about the whole sampler

print(sampling_steps.model)
# Access the representation of the model that was used to build the sampler
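One way to get this behaviour is a read-only mapping that wraps the blocks. The sketch below is an assumption about how it could be done, not existing AeMCMC code: SamplingSteps is a hypothetical name, and a block is any object with a sampling_steps dictionary and, optionally, a parameters dictionary.

```python
from collections.abc import Mapping
from typing import Any, Dict, Iterator, List


class SamplingSteps(Mapping):
    """Read-only view over several blocks, keyed by random variable."""

    def __init__(self, blocks: List[Any], model: Any = None):
        self.blocks = list(blocks)
        self.model = model

    def __getitem__(self, rv: Any) -> Any:
        # Look the variable up in each block in turn.
        for block in self.blocks:
            if rv in block.sampling_steps:
                return block.sampling_steps[rv]
        raise KeyError(rv)

    def __iter__(self) -> Iterator[Any]:
        for block in self.blocks:
            yield from block.sampling_steps

    def __len__(self) -> int:
        return sum(len(block.sampling_steps) for block in self.blocks)

    @property
    def parameters(self) -> Dict[str, Any]:
        # Blocks without parameters (e.g. Gibbs steps) contribute nothing.
        merged: Dict[str, Any] = {}
        for block in self.blocks:
            merged.update(getattr(block, "parameters", None) or {})
        return merged

    def __str__(self) -> str:
        return "\n\n".join(str(block) for block in self.blocks)
```

Merging the parameters by name only works while a single parametrized block is present; with several blocks the keys would have to be namespaced per block.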

By the way, the need to access the graph representation that is used by the samplers means that the transforms will need to happen outside of joint_logprob. At the very least, the transformations should be applied to random variables.

codecov[bot] commented 2 years ago

Codecov Report

Base: 97.38% // Head: 97.41% // Increases project coverage by +0.02% :tada:

Coverage data is based on head (3a783e4) compared to base (ed60d94). Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #68      +/-   ##
==========================================
+ Coverage   97.38%   97.41%   +0.02%
==========================================
  Files           9        9
  Lines         613      619       +6
  Branches       60       58       -2
==========================================
+ Hits          597      603       +6
  Misses          5        5
  Partials       11       11
```

| [Impacted Files](https://codecov.io/gh/aesara-devs/aemcmc/pull/68?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=None) | Coverage Δ | |
|---|---|---|
| [aemcmc/basic.py](https://codecov.io/gh/aesara-devs/aemcmc/pull/68/diff?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=None#diff-YWVtY21jL2Jhc2ljLnB5) | `100.00% <100.00%> (ø)` | |
| [aemcmc/nuts.py](https://codecov.io/gh/aesara-devs/aemcmc/pull/68/diff?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=None#diff-YWVtY21jL251dHMucHk=) | `98.03% <100.00%> (-0.04%)` | :arrow_down: |


rlouf commented 2 years ago

Progress report

The NUTS sampler is integrated in construct_sampler; it uses the previous sampling_step to condition the logprob. There might be a performance trade-off with the current implementation as I re-initialize the state at every step; hopefully Aesara is able to identify that the extra gradient computation is not needed, but that's something we will need to check.

It is currently tested on a NUTS-only example. Doing so I discovered a bug when the variable is a scalar:

import aesara.tensor as at

srng = at.random.RandomStream(0)
tau_rv = srng.halfcauchy(0, 1, name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, name="Y")

Which raises:

TypeError: Inconsistency in the inner graph of scan 'scan_fn' : an input and an output are associated with the same recurrent state and should have compatible types but have type 'TensorType(float64, (1,))' and 'TensorType(float64, (None,))' respectively.

This is reminiscent of shape issues I encountered in aehmc itself, and I think the issue has to be solved there since the model otherwise samples with no difficulty.

Most importantly, I need an example where there's both a Gibbs sampling step / closed form posterior and a NUTS sampling step.

To summarize, in order of priority:

rlouf commented 2 years ago

This is ready for review: