aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License

Refactor the kernel and adaptation API #66

Closed: rlouf closed this 2 years ago

rlouf commented 2 years ago

In this PR we refactor the kernel and adaptation APIs so that the adaptation can be run interactively, one step at a time. Adaptation is currently exposed as a single function that performs the full adaptation and returns the parameter values:

import aesara.tensor as at
from aesara.tensor.var import TensorVariable
from aeppl import joint_logprob
from aehmc import nuts, window_adaptation

srng = at.random.RandomStream(seed=0)
Y_rv = srng.normal(1, 2)

def logprob_fn(y: TensorVariable):
    logprob = joint_logprob({Y_rv: y})
    return logprob

def kernel_factory(inverse_mass_matrix: TensorVariable):
    return nuts.kernel(
        srng,
        logprob_fn,
        inverse_mass_matrix,
    )

y_vv = Y_rv.clone()
initial_state = nuts.new_state(y_vv, logprob_fn)

state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
    kernel_factory, initial_state, num_steps=1000
)
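In window adaptation, the warmup that window_adaptation.run performs in one go is usually split into phases. Fast buffers tune only the step size, and slow windows of doubling length also estimate the mass matrix. A step-by-step adaptation kernel therefore has to know where each step falls, which is presumably what the warmup_step index passed to update_adapt later in this thread is used for. The sketch below assumes aehmc follows Stan's scheme, as the slow_final function mentioned later suggests, and it also assumes Stan's default buffer sizes (75, 25, 50). build_schedule is a hypothetical helper and is not part of aehmc.

```python
from typing import List, Tuple


def build_schedule(
    num_steps: int,
    initial_buffer_size: int = 75,  # assumed Stan default
    first_window_size: int = 25,  # assumed Stan default
    final_buffer_size: int = 50,  # assumed Stan default
) -> List[Tuple[bool, bool]]:
    """Hypothetical helper: one (is_slow_window, is_window_end) pair per step.

    Fast buffers adapt only the step size; slow windows also accumulate
    mass-matrix statistics, which are turned into a new estimate at the
    end of each window.
    """
    if num_steps < initial_buffer_size + first_window_size + final_buffer_size:
        raise ValueError("num_steps is too small for this schedule")

    # Initial fast buffer: step size only.
    schedule: List[Tuple[bool, bool]] = [(False, False)] * initial_buffer_size

    # Slow windows whose length doubles each time.
    remaining = num_steps - initial_buffer_size - final_buffer_size
    window_size = first_window_size
    while remaining > 0:
        # If the next (doubled) window would not fit, stretch the current
        # window so that it absorbs all remaining slow steps.
        if remaining < 3 * window_size:
            window_size = remaining
        schedule += [(True, False)] * (window_size - 1) + [(True, True)]
        remaining -= window_size
        window_size *= 2

    # Final fast buffer: step size only, mass matrix kept fixed.
    schedule += [(False, False)] * final_buffer_size
    return schedule


schedule = build_schedule(1000)
window_ends = [i for i, (_, end) in enumerate(schedule) if end]
print(window_ends)  # [99, 149, 249, 449, 949]
```

With 1000 steps this gives slow windows of 25, 50, 100, 200 and 500 steps.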

However, this monolithic process is not suitable when, for instance, the NUTS kernel is part of a Metropolis-within-Gibbs scheme and some of the variables are not updated with NUTS. We should instead provide an adaptation kernel, so that an update step would (roughly) look like:

new_state, updates_step = nuts_kernel(*state, step_size, inverse_mass_matrix)
new_adaptation_state, step_size, inverse_mass_matrix, updates_adapt = adapt_kernel(
    adaptation_state, new_state
)

This way we should be able to completely decouple the NUTS kernel from the adaptation schemes in aemcmc. This PR was partially motivated by https://github.com/aesara-devs/aemcmc/pull/35 and a professional project. Closes #65
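The plain-Python toy below illustrates the decoupled interface sketched above outside of Aesara. Every name in it is hypothetical and none of it is aehmc's API. A random-walk Metropolis step stands in for the NUTS kernel, and a simple Robbins-Monro update of the log step size replaces aehmc's step-size adaptation. A second variable is updated by an exact Gibbs draw that the adaptation never sees, which is the Metropolis-within-Gibbs situation described above.

```python
import math
import random
from typing import Callable, List, NamedTuple, Tuple


class AdaptationState(NamedTuple):
    log_step_size: float
    step: int


def mh_kernel(
    rng: random.Random, x: float, logprob_fn: Callable[[float], float], step_size: float
) -> Tuple[float, float]:
    """One random-walk Metropolis step (a stand-in for the NUTS kernel)."""
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    accept_prob = math.exp(min(0.0, logprob_fn(proposal) - logprob_fn(x)))
    if rng.random() < accept_prob:
        return proposal, accept_prob
    return x, accept_prob


def adapt_kernel(
    state: AdaptationState, accept_prob: float, target: float = 0.44
) -> Tuple[AdaptationState, float]:
    """Robbins-Monro update of the log step size toward a target acceptance rate."""
    step = state.step + 1
    # The decreasing gain step**-0.6 lets the step size settle over time.
    log_step_size = state.log_step_size + (accept_prob - target) * step ** -0.6
    return AdaptationState(log_step_size, step), math.exp(log_step_size)


def run(num_steps: int = 2000, seed: int = 0) -> Tuple[float, float]:
    """Metropolis-within-Gibbs on p(x, y) proportional to N(x; 0, 1) N(y; x, 1)."""
    rng = random.Random(seed)
    x = 0.0
    adaptation_state = AdaptationState(log_step_size=0.0, step=0)
    step_size = 1.0
    accept_probs: List[float] = []
    for _ in range(num_steps):
        # Exact Gibbs draw y | x ~ N(x, 1): this update is not adapted.
        y = rng.gauss(x, 1.0)

        # Log density of x | y, up to an additive constant.
        def logprob_x(v: float, y: float = y) -> float:
            return -0.5 * v * v - 0.5 * (y - v) ** 2

        # 1. Advance x with the kernel, using the current step size.
        x, accept_prob = mh_kernel(rng, x, logprob_x, step_size)
        # 2. Feed the kernel's output to the separate adaptation kernel.
        adaptation_state, step_size = adapt_kernel(adaptation_state, accept_prob)
        accept_probs.append(accept_prob)

    recent = accept_probs[-500:]
    return step_size, sum(recent) / len(recent)


step_size, mean_accept = run()
print(step_size, mean_accept)
```

The point is the shape of the loop. The sampling kernel and the adaptation kernel each take and return their own state. The caller decides which variables go through the adapted kernel and when adaptation runs.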

rlouf commented 2 years ago

When running test_warmup_vector, Scan's validate_inner_graph raises an error whose origin I can't completely understand (while building the …). Maybe writing about it will help me find the reason.

The issue is in the scan loop within the run function in window_adaptation.py. If I add a breakpoint after initializing the warmup state and before entering the scan loop, I observe the following:

(da_state, mm_state), parameters = adapt_init(initial_state)
breakpoint()

>>> mm_state[0].type  # mean
# TensorType(float64, (None,))
>>> mm_state[1].type  # m2
# TensorType(float64, (None,))

This is expected, since y_vv.type is also TensorType(float64, (None,)). However, when I add a breakpoint inside the scan loop, I observe the following:

# Advance the chain by one step
chain_state, inner_updates = kernel(*chain_state, *parameters)

# Update the warmup state and parameters
warmup_state, parameters = update_adapt(
    warmup_step, warmup_state, parameters, chain_state
)
breakpoint()

>>> warmup_state[1][0].type  # mean
# TensorType(float64, (1,))
>>> warmup_state[1][1].type  # m2
# TensorType(float64, (1,))

The position of the chain, from which mean and m2 are computed, however, still has an unspecified length:

>>> chain_state[0].type
# TensorType(float64, (None,))

Several things that I find confusing:

If I comment out the body of validate_inner_graph, I get an error when running the function:

E   ValueError: Dimension 0 in Rebroadcast's input was supposed to be 1 (got 2 instead)
E   Apply node that caused the error: Rebroadcast{(0, True)}(if{}.0)
E   Toposort index: 99
E   Inputs types: [TensorType(float64, (None,))]
E   Inputs shapes: [(2,)]
E   Inputs strides: [(8,)]
E   Inputs values: [array([0., 0.])]
E   Outputs clients: [[if{inplace}(Subtensor{int64}.0, TensorConstant{(1,) of 0.0}, TensorConstant{(1,) of 0.0}, TensorConstant{0}, Rebroadcast{(0, True)}.0, Rebroadcast{(0, True)}.0, if{}.2)]]
E   
E   Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
E     File "/home/remi/.virtualenvs/aehmc/lib/python3.10/site-packages/_pytest/python.py", line 192, in pytest_pyfunc_call
E       result = testfunction(**testargs)
E     File "/home/remi/projects/aehmc/tests/test_hmc.py", line 67, in test_warmup_vector
E       state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
E     File "/home/remi/projects/aehmc/aehmc/window_adaptation.py", line 72, in run
E       state, updates = aesara.scan(
E     File "/home/remi/.virtualenvs/aehmc/lib/python3.10/site-packages/aesara/scan/basic.py", line 864, in scan
E       raw_inner_outputs = fn(*args)
E     File "/home/remi/projects/aehmc/aehmc/window_adaptation.py", line 55, in one_step
E       warmup_state, parameters = update_adapt(
E     File "/home/remi/projects/aehmc/aehmc/window_adaptation.py", line 166, in update
E       warmup_state, parameters = where_warmup_state(
E     File "/home/remi/projects/aehmc/aehmc/window_adaptation.py", line 180, in where_warmup_state
E       mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state)
E     File "/home/remi/.virtualenvs/aehmc/lib/python3.10/site-packages/aesara/ifelse.py", line 359, in ifelse
E       else_branch_elem = then_branch_elem.type.filter_variable(else_branch_elem)
E   
E   HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

This confirms that there is an issue with the shapes of mean and m2.
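The following toy model makes the failure mode concrete. It is not Aesara's implementation. As the backtrace shows, ifelse converts one branch to the other branch's type with filter_variable. The resulting Rebroadcast{(0, True)} node can only check at runtime that dimension 0 really has length 1. A static length of None means "unknown until runtime". The helpers first_mismatch and runtime_shape_check are hypothetical.

```python
from typing import Optional, Sequence, Tuple

# None stands for "length unknown until runtime", as in TensorType(float64, (None,)).
StaticShape = Tuple[Optional[int], ...]


def first_mismatch(static_shape: StaticShape, runtime_shape: Sequence[int]) -> Optional[int]:
    """Return the first dimension where runtime_shape violates static_shape, or None."""
    if len(static_shape) != len(runtime_shape):
        raise ValueError("number of dimensions differs")
    for dim, (expected, actual) in enumerate(zip(static_shape, runtime_shape)):
        if expected is not None and expected != actual:
            return dim
    return None


def runtime_shape_check(static_shape: StaticShape, runtime_shape: Sequence[int]) -> None:
    """Toy version of the deferred check a (None,) -> (1,) conversion requires."""
    dim = first_mismatch(static_shape, runtime_shape)
    if dim is not None:
        raise ValueError(
            f"Dimension {dim} was supposed to be {static_shape[dim]} "
            f"(got {runtime_shape[dim]} instead)"
        )


# The chain position: static shape (None,), runtime shape (2,).
runtime_shape_check((None,), (2,))  # passes: the length is not fixed statically
try:
    # A value forced into a (1,) type fails once its real length (2) is known.
    runtime_shape_check((1,), (2,))
except ValueError as err:
    print(err)  # Dimension 0 was supposed to be 1 (got 2 instead)
```

This is why disabling validate_inner_graph only postpones the problem. The (1,) type is already wrong when the graph is built, and the error then surfaces during execution.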

Update

I've tracked the issue down to the slow_final function, specifically to the following line:

_, new_mm_state = mm_init(inverse_mass_matrix.ndim)
breakpoint()
>>> new_mm_state[0].type
# TensorType(float64, (1,))

This is obvious in hindsight: inverse_mass_matrix.ndim is equal to 1 (we're adapting a diagonal mass matrix), so mm_init builds mean and m2 with static shape (1,) instead of the size of the chain's position.
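The bug can be reproduced with a Welford-style running estimate, assuming the mass-matrix state is one, as the names mean and m2 suggest. The sketch shows why the state must be sized from the number of coordinates in the position and not from the ndim of a diagonal inverse mass matrix. welford_init, welford_update and welford_variance are hypothetical stand-ins for mm_init and its update, not aehmc code. Unlike the Aesara case, where the mismatch showed up as a (1,) static shape clashing with (None,), this pure-Python toy fails as soon as it is updated.

```python
from typing import List, Sequence, Tuple

# (count, mean, m2): hypothetical per-coordinate Welford state.
WelfordState = Tuple[int, List[float], List[float]]


def welford_init(size: int) -> WelfordState:
    """Start a running estimate for a position with `size` coordinates."""
    return 0, [0.0] * size, [0.0] * size


def welford_update(state: WelfordState, position: Sequence[float]) -> WelfordState:
    """Fold one chain position into the running mean and m2."""
    count, mean, m2 = state
    if len(position) != len(mean):
        raise ValueError(
            f"state has {len(mean)} coordinates but position has {len(position)}"
        )
    count += 1
    new_mean, new_m2 = [], []
    for mu, s, x in zip(mean, m2, position):
        delta = x - mu
        mu += delta / count
        # Welford's update uses both the old (delta) and the new mean.
        s += delta * (x - mu)
        new_mean.append(mu)
        new_m2.append(s)
    return count, new_mean, new_m2


def welford_variance(state: WelfordState) -> List[float]:
    """Per-coordinate sample variance, i.e. a diagonal mass-matrix estimate."""
    count, _, m2 = state
    return [s / (count - 1) for s in m2]


# Sizing the state from the position works for a 2-coordinate chain...
state = welford_init(2)
for x in ([1.0, 10.0], [2.0, 20.0], [3.0, 30.0]):
    state = welford_update(state, x)
print(welford_variance(state))  # [1.0, 100.0]

# ...whereas sizing it from the ndim of a diagonal inverse mass matrix (1)
# yields a one-coordinate state that cannot absorb a 2-coordinate position.
diagonal_ndim = 1
try:
    welford_update(welford_init(diagonal_ndim), [1.0, 10.0])
except ValueError as err:
    print(err)  # state has 1 coordinates but position has 2
```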

rlouf commented 2 years ago

Passing the right size to mm_init solves the issue. Two comments:

rlouf commented 2 years ago

The test that currently fails in CI passes locally with the fix in https://github.com/aesara-devs/aesara/pull/1035. Overall, I am very happy with this design change, which will certainly pay dividends in aemcmc.

codecov[bot] commented 2 years ago

Codecov Report

Merging #66 (5eb4031) into main (dbce929) will not change coverage. The diff coverage is 100.00%.

@@            Coverage Diff            @@
##              main       #66   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           13        13           
  Lines          531       541   +10     
  Branches        30        31    +1     
=========================================
+ Hits           531       541   +10     

Impacted Files                Coverage Δ
aehmc/hmc.py                  100.00% <100.00%> (ø)
aehmc/nuts.py                 100.00% <100.00%> (ø)
aehmc/window_adaptation.py    100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update dbce929...5eb4031.

brandonwillard commented 2 years ago

The first two commits look like they need to be squashed.

rlouf commented 2 years ago

You could argue they can be kept separate, but I don't see a situation in which we would want to reverse this change for NUTS and not for HMC. Squashed.