To understand the kind of changes that having several possible samplers (including parametrized samplers) will require, let's work through a non-trivial example: building sampling functions for the Horseshoe prior, taken from AeMCMC's test suite:
```python
import aesara.tensor as at

from aemcmc.basic import construct_sampler

srng = at.random.RandomStream(0)

X = at.matrix("X")

# Horseshoe `beta_rv`
tau_rv = srng.halfcauchy(0, 1, name="tau")
lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")

a = at.scalar("a")
b = at.scalar("b")
h_rv = srng.gamma(a, b, name="h")

# Negative-binomial regression
eta = X @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.nbinom(h_rv, p, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"
```
We observe `Y_rv`, and we want to sample from the posterior distribution of `tau_rv`, `lmbda_rv`, `beta_rv` and `h_rv`. AeMCMC currently provides a `construct_sampler` function:
```python
sample_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})
```
The `sample_steps` dictionary maps each random variable to the sampling step that was assigned to it. We can print the graph of the sampling step assigned to `lmbda_rv`:
```python
import aesara

print(f"Variables to sample: {sample_steps.keys()}\n")
# Variables to sample: dict_keys([tau, lambda, beta, h])

aesara.dprint(sample_steps[lmbda_rv])
# Elemwise{reciprocal,no_inplace} [id A] 'lambda_posterior'
# |Elemwise{sqrt,no_inplace} [id B]
# |exponential_rv{0, (0,), floatX, False}.1 [id C]
# |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F5A0A0>) [id D]
# |TensorConstant{[]} [id E]
# |TensorConstant{11} [id F]
# |Elemwise{add,no_inplace} [id G]
# |exponential_rv{0, (0,), floatX, False}.1 [id H]
# | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F589E0>) [id I]
# | |TensorConstant{[]} [id J]
# | |TensorConstant{11} [id K]
# | |Elemwise{add,no_inplace} [id L]
# | |InplaceDimShuffle{x} [id M]
# | | |TensorConstant{1} [id N]
# | |Elemwise{reciprocal,no_inplace} [id O]
# | |Elemwise{pow,no_inplace} [id P]
# | |lambda [id Q]
# | |InplaceDimShuffle{x} [id R]
# | |TensorConstant{2} [id S]
# |Elemwise{true_div,no_inplace} [id T]
# |Elemwise{mul,no_inplace} [id U]
# | |Elemwise{mul,no_inplace} [id V]
# | | |InplaceDimShuffle{x} [id W]
# | | | |TensorConstant{0.5} [id X]
# | | |Elemwise{pow,no_inplace} [id Y]
# | | |beta [id Z]
# | | |InplaceDimShuffle{x} [id BA]
# | | |TensorConstant{2} [id BB]
# | |InplaceDimShuffle{x} [id BC]
# | |Elemwise{reciprocal,no_inplace} [id BD]
# | |Elemwise{pow,no_inplace} [id BE]
# | |tau [id BF]
# | |TensorConstant{2} [id BG]
# |InplaceDimShuffle{x} [id BH]
# |TensorConstant{1.0} [id BI]
```
Samplers update the RNG state, and the caller will need to pass these updates to the compiler later, so we return them as well. `updates` is a dictionary that maps the shared state of the random number generator we passed via `srng` to its updated value:
```python
print(updates)
# {RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCB6F59B60>): for{cpu,scan_fn}.1}
```
And finally, `initial_values` contains the initial value variables of the random variables we wish to sample from:
```python
print(initial_values)
# {tau: tau, lambda: lambda, beta: beta, h: h}
```
We can now easily build the graph for the sampler:
```python
sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

outputs = [sample_steps[rv] for rv in sample_vars]
inputs = [X, a, b, y_vv] + [initial_values[rv] for rv in sample_vars]

# Don't forget the updates!
sampler = aesara.function(
    inputs,
    outputs,
    updates=updates,
)
```
And we can run the `sampler` function in a Python loop, as sketched below. Although this works perfectly for the Gibbs samplers, the downstream caller has no high-level information about what transformations were applied to the graph, or what samplers were assigned to the variables. They would have to reverse-engineer that information from the graph they receive. Even without considering NUTS, this becomes problematic the day we return a stream of samplers: how are humans (or machines) supposed to reason about what AeMCMC returns?
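For concreteness, here is a minimal sketch of such a driver loop. The data values (`X_val`, `a_val`, `b_val`, `y_val`), the initial values, and the iteration count are hypothetical placeholders, not part of the original example:

```python
import numpy as np

# Hypothetical data standing in for the observed design matrix, the Gamma
# hyperparameters, and the observed counts
rng = np.random.default_rng(0)
X_val = rng.normal(size=(100, 3))
a_val, b_val = 1.0, 1.0
y_val = rng.poisson(1.0, size=100)

# One initial value per sampled variable, in the order of `sample_vars`:
# tau (scalar), lambda (vector), beta (vector), h (scalar)
current = [1.0, np.ones(3), np.zeros(3), 1.0]

samples = []
for _ in range(1000):
    # Each call returns one new draw per variable and advances the RNG state
    # stored in the shared variables via `updates`
    current = sampler(X_val, a_val, b_val, y_val, *current)
    samples.append(current)
```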
Other issues related to missing information arise with NUTS.

We can simply create sampler types. Let us forget for a minute everything we know about the Horseshoe prior; we pass the previous model to AeMCMC but have no idea what the returned sampling steps may be:
```python
sample_steps, updates, initial_values = aemcmc.construct_sampler(srng, {Y_rv: y_vv})
```
What would the interface look like? As a statistician, I would like to get some textual information about the sampling steps, for instance:
```python
print(sample_steps[lmbda_rv])
# "Gibbs sampler"
# - Mathematical equations that describe the sampler
# - Names of the variables we're conditioning on at this step (beta, tau)
# - It does not have any parameters
```
But AeMCMC can also return parametrized sampling steps. If NUTS were assigned, I would like (need) to know:
```python
print(sample_steps[lmbda_rv])
# "NUTS sampler"
# - What are the parameters? What do we already know about them (shapes, types)?
# - What transformations were applied to the random variables?
# - What other variables are jointly sampled with NUTS?
```
A machine caller would also need some of this information, especially whether it will need to provide values for some parameters. We can encapsulate this information in a data structure:
```python
from dataclasses import dataclass
from typing import Dict, Optional

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from aeppl.transforms import RVTransform


@dataclass(frozen=True)
class NUTSSamplingStep:
    sampling_steps: Dict[RandomVariable, TensorVariable]
    updates: Dict
    parameters: Dict[str, TensorVariable]
    initial_values: Dict[RandomVariable, TensorVariable]
    transforms: Optional[Dict[RandomVariable, RVTransform]]

    _sampler_name = "NUTS"

    def __post_init__(self):
        # `frozen=True` disallows regular attribute assignment, so we go
        # through `object.__setattr__`
        object.__setattr__(
            self,
            "rvs_to_sample",
            tuple(self.sampling_steps.keys()),
        )

    def __str__(self):
        return (
            f"{self._sampler_name}\n"
            f"Transforms: {self.transforms}\n"
            f"Variables sampled together: {self.rvs_to_sample}"
        )
```
This data structure is created by calling a `construct_nuts_sampler` function:
```python
def construct_nuts_sampler(srng, rvs_to_sample, rvs_to_values):
    # 1. Initialize the parameter values
    parameters = {
        "step_size": None,
        "inverse_mass_matrix": None,
    }

    # 2. Create initial value variables in the original space
    initial_values = ...

    # 3. Look for transformations and apply them to the initial value variables
    transforms = ...

    # 4. Create the `rp_map`
    # 5. Build the logprob graph
    # 6. Build the NUTS sampling step
    sampling_steps, updates = ...

    return NUTSSamplingStep(sampling_steps, updates, parameters, initial_values, transforms)
```
Notice that I have grouped the variables assigned to NUTS under the same data structure, a block. This is necessary because we need to know that only one value of the `step_size` is needed for these different variables. Extra attention is only required from machine callers when the sampler is parametrized; we can thus implement a hierarchy of sampler types, `SamplingStep` and `ParametrizedSamplingStep`, and have every sampler inherit from one of these. The hierarchy can be extended; for instance, algorithms in the HMC family all need a step size and an inverse mass matrix parameter.
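A minimal sketch of what this hierarchy could look like; the class names other than `SamplingStep` and `ParametrizedSamplingStep`, and the exact fields, are assumptions rather than a settled design:

```python
from dataclasses import dataclass
from typing import Dict

from aesara.tensor.var import TensorVariable


@dataclass(frozen=True)
class SamplingStep:
    """A sampling step with no tunable parameters (e.g. a Gibbs step)."""

    # Maps each random variable to the graph that produces its next draw
    sampling_steps: Dict[TensorVariable, TensorVariable]
    updates: Dict
    initial_values: Dict[TensorVariable, TensorVariable]


@dataclass(frozen=True)
class ParametrizedSamplingStep(SamplingStep):
    """A sampling step whose transition kernel exposes tunable parameters."""

    # Callers (human or machine) must provide or tune these values
    parameters: Dict[str, TensorVariable]


@dataclass(frozen=True)
class HMCFamilySamplingStep(ParametrizedSamplingStep):
    """HMC-family steps all share a step size and an inverse mass matrix."""

    # `parameters` is expected to contain at least "step_size" and
    # "inverse_mass_matrix"
```

With this split, a downstream caller only has to inspect `parameters` when the step it receives is a `ParametrizedSamplingStep`.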
`construct_sampler` then returns a list of sampling steps:
```python
sampling_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})
```
We can add some syntactic sugar, e.g. by having `sampling_steps` be a data structure where the sampling step for each random variable can be accessed by key, so as to preserve the original design:
```python
sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

sampling_steps, updates, initial_values = construct_sampler(srng, {Y_rv: y_vv})
outputs = [sampling_steps[rv] for rv in sample_vars]  # returns e.g. NUTSSamplingStep.sampling_steps[rv]

print(sampling_steps.parameters)
# {"step_size": at.scalar(), "inverse_mass_matrix": at.vector()}

print(sampling_steps)
# Print information about the whole sampler

print(sampling_steps.model)
# Access the representation of the model that was used to build the sampler
```
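A sketch of what such a container could look like; the `SamplerSteps` name, its fields, and the `model` attribute are assumptions about a possible design, not AeMCMC's current API:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SamplerSteps:
    # The sampling steps (e.g. Gibbs blocks, a NUTSSamplingStep) in the order
    # they were assigned
    steps: Tuple
    # A representation of the (possibly rewritten) model used to build them
    model: object

    def __getitem__(self, rv):
        # Preserve the original `sample_steps[rv]` access pattern by
        # dispatching to the step that samples `rv`
        for step in self.steps:
            if rv in step.sampling_steps:
                return step.sampling_steps[rv]
        raise KeyError(rv)

    @property
    def parameters(self):
        # Merge the parameters exposed by every parametrized step
        params = {}
        for step in self.steps:
            params.update(getattr(step, "parameters", {}))
        return params
```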
By the way, the need to access the graph representation used by the samplers means that the transforms will need to happen outside of `joint_logprob`. At the very least, the transformations should be applied to the random variables.
The NUTS sampler is integrated in `construct_sampler`; it uses the previous `sampling_step` to condition the logprob. There might be a performance trade-off with the current implementation, as I re-initialize the state at every step; hopefully Aesara is able to identify that the extra gradient computation is not needed, but that is something we will need to check.
It is currently tested on a NUTS-only example. Doing so, I discovered a bug when the variable is a scalar:
```python
import aesara.tensor as at

srng = at.random.RandomStream(0)
tau_rv = srng.halfcauchy(0, 1, name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, name="Y")
```
which raises:

```
TypeError: Inconsistency in the inner graph of scan 'scan_fn' : an input and an output are associated with the same recurrent state and should have compatible types but have type 'TensorType(float64, (1,))' and 'TensorType(float64, (None,))' respectively.
```
This is reminiscent of shape issues I encountered in `aehmc` itself, and I think the issue has to be solved there, since it otherwise samples with no difficulty.
Most importantly, I need an example where there's both a Gibbs sampling step / closed form posterior and a NUTS sampling step.
To summarize, in order of priority:
- [...] `inverse_mass_matrix` when I initialize it. It doesn't solve the issue in Aesara, but this is good practice anyway, and the tests pass.

This is ready for review: in this PR we assign the NUTS sampler to the variables that have not been assigned a sampler in `construct_sampler`. More details to come.
Closes #47