aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License

Add README #25

Closed rlouf closed 3 years ago

rlouf commented 3 years ago

Closes #16

rlouf commented 3 years ago

Tested with aeppl 0.0.9 and the development branch.

I documented the issue I have with scalar random variables. When I build the following model:

import aesara
import aesara.tensor as at
from aeppl import joint_logprob
from aesara.tensor.random.utils import RandomStream

from aehmc import hmc

Y_rv = at.random.normal(0, 1, size=(2,), name="Y")

def logprob_fn(y):
    logprob = joint_logprob(Y_rv, {Y_rv: y})
    return logprob

# Build the transition kernel
step_size = at.scalar("step_size")
inverse_mass_matrix = at.vector("inverse_mass_matrix")
num_integration_steps = at.scalar("num_integration_steps", dtype="int32")

srng = RandomStream(seed=1)
kernel = hmc.kernel(srng, logprob_fn, 1e-3, at.ones(2), 10)

# Compile a function that updates the chain
Y_tt = Y_rv.clone()
potential_energy = -logprob_fn(Y_tt)
potential_energy_grad = aesara.grad(potential_energy, wrt=Y_tt)

next_step = kernel(Y_tt, potential_energy, potential_energy_grad)

I get the following error:

Traceback (most recent call last):
  File "/home/remi/projects/aehmc/example.py", line 29, in <module>
    next_step = kernel(Y_tt, potential_energy, potential_energy_grad)
  File "/home/remi/projects/aehmc/aehmc/hmc.py", line 77, in step
    ) = proposal_generator(srng, q, p, potential_energy, potential_energy_grad)
  File "/home/remi/projects/aehmc/aehmc/hmc.py", line 105, in propose
    new_q, new_p, new_potential_energy, new_potential_energy_grad = integrate(
  File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 48, in integrate
    [q, p, energy, energy_grad], _ = aesara.scan(
  File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/scan/basic.py", line 751, in scan
    condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
  File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 43, in one_step
    new_state = integrator(
  File "/home/remi/projects/aehmc/aehmc/integrators.py", line 67, in one_step
    new_potential_energy_grad = aesara.grad(new_potential_energy, new_position)
  File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/gradient.py", line 615, in grad
    handle_disconnected(elem)
  File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/gradient.py", line 601, in handle_disconnected
    raise DisconnectedInputError(message)
aesara.gradient.DisconnectedInputError:  
Backtrace when that variable is created:

  File "/home/remi/projects/aehmc/example.py", line 29, in <module>
    next_step = kernel(Y_tt, potential_energy, potential_energy_grad)
  File "/home/remi/projects/aehmc/aehmc/hmc.py", line 77, in step
    ) = proposal_generator(srng, q, p, potential_energy, potential_energy_grad)
  File "/home/remi/projects/aehmc/aehmc/hmc.py", line 105, in propose
    new_q, new_p, new_potential_energy, new_potential_energy_grad = integrate(
  File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 48, in integrate
    [q, p, energy, energy_grad], _ = aesara.scan(
  File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/scan/basic.py", line 751, in scan
    condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
  File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 43, in one_step
    new_state = integrator(
  File "/home/remi/projects/aehmc/aehmc/integrators.py", line 64, in one_step
    new_position = position + a2 * step_size * kinetic_grad

It looks like it may be specific to having only one random variable in the model. Indeed, the following compiles:

Y_rv = at.random.normal(0, 1, size=(2,), name="Y")
Z_rv = at.random.normal(Y_rv, 1)

def logprob_fn(y):
    logprob = joint_logprob(Z_rv, {Y_rv: y, Z_rv: at.as_tensor([.1, .1])})
    return logprob

brandonwillard commented 3 years ago

The scalar issue arises in the one_step function created in velocity_verlet when aesara.grad(new_potential_energy, new_position) is evaluated.

The following prints the graphs of new_potential_energy and new_position:

aesara.dprint([new_potential_energy, new_position])
# Elemwise{neg,no_inplace} [id A] ''
#  |Sum{acc_dtype=float64} [id B] ''
#    |Assert{msg='sigma > 0'} [id C] ''
#      |Elemwise{sub,no_inplace} [id D] ''
#      | |Elemwise{sub,no_inplace} [id E] ''
#      | | |Elemwise{mul,no_inplace} [id F] ''
#      | | | |InplaceDimShuffle{x} [id G] ''
#      | | | | |TensorConstant{-0.5} [id H]
#      | | | |Elemwise{pow,no_inplace} [id I] ''
#      | | |   |Elemwise{true_div,no_inplace} [id J] ''
#      | | |   | |Elemwise{sub,no_inplace} [id K] ''
#      | | |   | | |Elemwise{add,no_inplace} [id L] ''
#      | | |   | | | |y[t-1] [id M]
#      | | |   | | | |Elemwise{mul,no_inplace} [id N] ''
#      | | |   | | |   |InplaceDimShuffle{x} [id O] ''
#      | | |   | | |   | |TensorConstant{0.001} [id P]
#      | | |   | | |   |Elemwise{add,no_inplace} [id Q] ''
#      | | |   | | |     |Elemwise{mul} [id R] ''
#      | | |   | | |     | |Elemwise{mul,no_inplace} [id S] ''
#      | | |   | | |     | | |InplaceDimShuffle{x} [id T] ''
#      | | |   | | |     | | | |Elemwise{mul} [id U] ''
#      | | |   | | |     | | |   |Elemwise{second,no_inplace} [id V] ''
#      | | |   | | |     | | |   | |Elemwise{mul,no_inplace} [id W] ''
#      | | |   | | |     | | |   | | |TensorConstant{0.5} [id X]
#      | | |   | | |     | | |   | | |dot [id Y] ''
#      | | |   | | |     | | |   | |   |Elemwise{mul,no_inplace} [id Z] ''
#      | | |   | | |     | | |   | |   | |Alloc [id BA] ''
#      | | |   | | |     | | |   | |   | | |TensorConstant{1.0} [id BB]
#      | | |   | | |     | | |   | |   | | |TensorConstant{2} [id BC]
#      | | |   | | |     | | |   | |   | |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | | |   | |   |   |<TensorType(float64, vector)> [id BE]
#      | | |   | | |     | | |   | |   |   |Elemwise{mul,no_inplace} [id BF] ''
#      | | |   | | |     | | |   | |   |     |InplaceDimShuffle{x} [id BG] ''
#      | | |   | | |     | | |   | |   |     | |TensorConstant{0.0005} [id BH]
#      | | |   | | |     | | |   | |   |     |<TensorType(float64, vector)> [id BI]
#      | | |   | | |     | | |   | |   |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | | |   | |TensorConstant{1.0} [id BJ]
#      | | |   | | |     | | |   |TensorConstant{0.5} [id X]
#      | | |   | | |     | | |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | |Alloc [id BA] ''
#      | | |   | | |     |Elemwise{mul,no_inplace} [id BK] ''
#      | | |   | | |       |InplaceDimShuffle{x} [id BL] ''
#      | | |   | | |       | |Elemwise{mul} [id U] ''
#      | | |   | | |       |Elemwise{mul,no_inplace} [id Z] ''
#      | | |   | | |InplaceDimShuffle{x} [id BM] ''
#      | | |   | |   |TensorConstant{0} [id BN]
#      | | |   | |InplaceDimShuffle{x} [id BO] ''
#      | | |   |   |TensorConstant{1} [id BP]
#      | | |   |InplaceDimShuffle{x} [id BQ] ''
#      | | |     |TensorConstant{2} [id BR]
#      | | |InplaceDimShuffle{x} [id BS] ''
#      | |   |Elemwise{log,no_inplace} [id BT] ''
#      | |     |TensorConstant{2.5066282746310002} [id BU]
#      | |InplaceDimShuffle{x} [id BV] ''
#      |   |Elemwise{log,no_inplace} [id BW] ''
#      |     |TensorConstant{1} [id BP]
#      |All [id BX] ''
#        |Elemwise{gt,no_inplace} [id BY] ''
#          |TensorConstant{1} [id BP]
#          |TensorConstant{0.0} [id BZ]
# Elemwise{add,no_inplace} [id CA] ''
#  |y[t-1] [id M]
#  |Elemwise{mul,no_inplace} [id CB] ''
#    |InplaceDimShuffle{x} [id CC] ''
#    | |TensorConstant{0.001} [id P]
#    |Elemwise{add,no_inplace} [id CD] ''
#      |Elemwise{mul} [id CE] ''
#      | |Elemwise{mul,no_inplace} [id CF] ''
#      | | |InplaceDimShuffle{x} [id CG] ''
#      | | | |Elemwise{mul} [id CH] ''
#      | | |   |Elemwise{second,no_inplace} [id CI] ''
#      | | |   | |Elemwise{mul,no_inplace} [id CJ] ''
#      | | |   | | |TensorConstant{0.5} [id X]
#      | | |   | | |dot [id CK] ''
#      | | |   | |   |Elemwise{mul,no_inplace} [id CL] ''
#      | | |   | |   | |Alloc [id CM] ''
#      | | |   | |   | | |TensorConstant{1.0} [id BB]
#      | | |   | |   | | |TensorConstant{2} [id BC]
#      | | |   | |   | |Elemwise{sub,no_inplace} [id CN] ''
#      | | |   | |   |   |<TensorType(float64, vector)> [id BE]
#      | | |   | |   |   |Elemwise{mul,no_inplace} [id CO] ''
#      | | |   | |   |     |InplaceDimShuffle{x} [id CP] ''
#      | | |   | |   |     | |TensorConstant{0.0005} [id BH]
#      | | |   | |   |     |<TensorType(float64, vector)> [id BI]
#      | | |   | |   |Elemwise{sub,no_inplace} [id CN] ''
#      | | |   | |TensorConstant{1.0} [id BJ]
#      | | |   |TensorConstant{0.5} [id X]
#      | | |Elemwise{sub,no_inplace} [id CN] ''
#      | |Alloc [id CM] ''
#      |Elemwise{mul,no_inplace} [id CQ] ''
#        |InplaceDimShuffle{x} [id CR] ''
#        | |Elemwise{mul} [id CH] ''
#        |Elemwise{mul,no_inplace} [id CL] ''

The graph with ID CA (i.e. new_position) is the variable with respect to which the gradient of new_potential_energy is computed, and a quick search for that ID shows that it does not appear anywhere in the graph of new_potential_energy. The structurally identical subgraph that does appear there carries a different ID (L).

Since new_potential_energy = potential_fn(new_position), we need to determine why potential_fn is not returning a graph that's a function of its input new_position.
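
To make the failure mode concrete, here is a minimal pure-Python sketch of why a structurally identical copy does not count as the same variable. It is not Aesara's implementation: the `Node` class and the `ancestors` and `depends_on` helpers are hypothetical names invented for this illustration. The only assumption carried over from the thread is that graph traversal matches variables by identity, which is what the differing IDs (L versus CA) in the print-out above suggest.

```python
class Node:
    """A toy graph node: an op name plus its input nodes (hypothetical, not Aesara)."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)

    def clone(self):
        # Same op and same inputs, but a brand-new Python object.
        return Node(self.op, self.inputs)


def ancestors(node):
    """Return every node reachable from `node`, including `node` itself."""
    seen = []
    stack = [node]
    while stack:
        current = stack.pop()
        if any(current is s for s in seen):
            continue
        seen.append(current)
        stack.extend(current.inputs)
    return seen


def depends_on(output, wrt):
    # Identity comparison: a clone of `wrt` does not count as `wrt`.
    return any(n is wrt for n in ancestors(output))


y = Node("y")
step = Node("step")
new_position = Node("add", [y, step])

# One energy graph is built on a clone of new_position, the other on the original.
energy_on_clone = Node("neg", [Node("sum", [new_position.clone()])])
energy_on_original = Node("neg", [Node("sum", [new_position])])

print(depends_on(energy_on_clone, new_position))     # False
print(depends_on(energy_on_original, new_position))  # True
print(depends_on(energy_on_clone, y))                # True
```

The last line shows that the cloned graph still reaches the atomic input `y`, just not the intermediate `new_position` node itself.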

If we set new_position.name = "new_position" and reproduce the graph print-out, we get the following:

new_position.name = "new_position"
aesara.dprint([potential_fn(new_position), new_position])
# Elemwise{neg,no_inplace} [id A] ''
#  |Sum{acc_dtype=float64} [id B] ''
#    |Assert{msg='sigma > 0'} [id C] 'new_position_logprob'
#      |Elemwise{sub,no_inplace} [id D] ''
#      | |Elemwise{sub,no_inplace} [id E] ''
#      | | |Elemwise{mul,no_inplace} [id F] ''
#      | | | |InplaceDimShuffle{x} [id G] ''
#      | | | | |TensorConstant{-0.5} [id H]
#      | | | |Elemwise{pow,no_inplace} [id I] ''
#      | | |   |Elemwise{true_div,no_inplace} [id J] ''
#      | | |   | |Elemwise{sub,no_inplace} [id K] ''
#      | | |   | | |Elemwise{add,no_inplace} [id L] 'new_position'
#      | | |   | | | |y[t-1] [id M]
#      | | |   | | | |Elemwise{mul,no_inplace} [id N] ''
#      | | |   | | |   |InplaceDimShuffle{x} [id O] ''
#      | | |   | | |   | |TensorConstant{0.001} [id P]
#      | | |   | | |   |Elemwise{add,no_inplace} [id Q] ''
#      | | |   | | |     |Elemwise{mul} [id R] ''
#      | | |   | | |     | |Elemwise{mul,no_inplace} [id S] ''
#      | | |   | | |     | | |InplaceDimShuffle{x} [id T] ''
#      | | |   | | |     | | | |Elemwise{mul} [id U] ''
#      | | |   | | |     | | |   |Elemwise{second,no_inplace} [id V] ''
#      | | |   | | |     | | |   | |Elemwise{mul,no_inplace} [id W] ''
#      | | |   | | |     | | |   | | |TensorConstant{0.5} [id X]
#      | | |   | | |     | | |   | | |dot [id Y] ''
#      | | |   | | |     | | |   | |   |Elemwise{mul,no_inplace} [id Z] ''
#      | | |   | | |     | | |   | |   | |Alloc [id BA] ''
#      | | |   | | |     | | |   | |   | | |TensorConstant{1.0} [id BB]
#      | | |   | | |     | | |   | |   | | |TensorConstant{2} [id BC]
#      | | |   | | |     | | |   | |   | |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | | |   | |   |   |<TensorType(float64, vector)> [id BE]
#      | | |   | | |     | | |   | |   |   |Elemwise{mul,no_inplace} [id BF] ''
#      | | |   | | |     | | |   | |   |     |InplaceDimShuffle{x} [id BG] ''
#      | | |   | | |     | | |   | |   |     | |TensorConstant{0.0005} [id BH]
#      | | |   | | |     | | |   | |   |     |<TensorType(float64, vector)> [id BI]
#      | | |   | | |     | | |   | |   |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | | |   | |TensorConstant{1.0} [id BJ]
#      | | |   | | |     | | |   |TensorConstant{0.5} [id X]
#      | | |   | | |     | | |Elemwise{sub,no_inplace} [id BD] ''
#      | | |   | | |     | |Alloc [id BA] ''
#      | | |   | | |     |Elemwise{mul,no_inplace} [id BK] ''
#      | | |   | | |       |InplaceDimShuffle{x} [id BL] ''
#      | | |   | | |       | |Elemwise{mul} [id U] ''
#      | | |   | | |       |Elemwise{mul,no_inplace} [id Z] ''
#      | | |   | | |InplaceDimShuffle{x} [id BM] ''
#      | | |   | |   |TensorConstant{0} [id BN]
#      | | |   | |InplaceDimShuffle{x} [id BO] ''
#      | | |   |   |TensorConstant{1} [id BP]
#      | | |   |InplaceDimShuffle{x} [id BQ] ''
#      | | |     |TensorConstant{2} [id BR]
#      | | |InplaceDimShuffle{x} [id BS] ''
#      | |   |Elemwise{log,no_inplace} [id BT] ''
#      | |     |TensorConstant{2.5066282746310002} [id BU]
#      | |InplaceDimShuffle{x} [id BV] ''
#      |   |Elemwise{log,no_inplace} [id BW] ''
#      |     |TensorConstant{1} [id BP]
#      |All [id BX] ''
#        |Elemwise{gt,no_inplace} [id BY] ''
#          |TensorConstant{1} [id BP]
#          |TensorConstant{0.0} [id BZ]
# Elemwise{add,no_inplace} [id CA] 'new_position'
#  |y[t-1] [id M]
#  |Elemwise{mul,no_inplace} [id CB] ''
#    |InplaceDimShuffle{x} [id CC] ''
#    | |TensorConstant{0.001} [id P]
#    |Elemwise{add,no_inplace} [id CD] ''
#      |Elemwise{mul} [id CE] ''
#      | |Elemwise{mul,no_inplace} [id CF] ''
#      | | |InplaceDimShuffle{x} [id CG] ''
#      | | | |Elemwise{mul} [id CH] ''
#      | | |   |Elemwise{second,no_inplace} [id CI] ''
#      | | |   | |Elemwise{mul,no_inplace} [id CJ] ''
#      | | |   | | |TensorConstant{0.5} [id X]
#      | | |   | | |dot [id CK] ''
#      | | |   | |   |Elemwise{mul,no_inplace} [id CL] ''
#      | | |   | |   | |Alloc [id CM] ''
#      | | |   | |   | | |TensorConstant{1.0} [id BB]
#      | | |   | |   | | |TensorConstant{2} [id BC]
#      | | |   | |   | |Elemwise{sub,no_inplace} [id CN] ''
#      | | |   | |   |   |<TensorType(float64, vector)> [id BE]
#      | | |   | |   |   |Elemwise{mul,no_inplace} [id CO] ''
#      | | |   | |   |     |InplaceDimShuffle{x} [id CP] ''
#      | | |   | |   |     | |TensorConstant{0.0005} [id BH]
#      | | |   | |   |     |<TensorType(float64, vector)> [id BI]
#      | | |   | |   |Elemwise{sub,no_inplace} [id CN] ''
#      | | |   | |TensorConstant{1.0} [id BJ]
#      | | |   |TensorConstant{0.5} [id X]
#      | | |Elemwise{sub,no_inplace} [id CN] ''
#      | |Alloc [id CM] ''
#      |Elemwise{mul,no_inplace} [id CQ] ''
#        |InplaceDimShuffle{x} [id CR] ''
#        | |Elemwise{mul} [id CH] ''
#        |Elemwise{mul,no_inplace} [id CL] ''

This output makes it clearer that new_position is in the graph output by potential_fn(new_position), but the implication is that it's been cloned, so it's no longer identical to the original new_position. This seems to be the underlying cause of the error.

We can look into this more from the AePPL side, but we can also take the aesara.grad step with respect to a new, "atomic" variable and then replace that variable with the original new_position graph in new_momentum.
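
The differentiate-then-substitute pattern can be sketched with a toy symbolic engine in plain Python. Everything here (`Expr`, `evaluate`, `grad`, `substitute`, and the quadratic `potential_fn`) is a hypothetical stand-in, not Aesara or AePPL code, and the toy does not reproduce the cloning behaviour itself. It only shows the two steps: take the gradient with respect to an atomic placeholder, then swap the placeholder for the non-atomic new_position graph.

```python
class Expr:
    """Toy symbolic expression (hypothetical; not an Aesara class)."""

    def __init__(self, op, *args):
        self.op = op      # "var", "const", "add" or "mul"
        self.args = args  # name for "var", value for "const", else sub-expressions


def evaluate(expr, env):
    if expr.op == "var":
        return env[expr.args[0]]
    if expr.op == "const":
        return expr.args[0]
    left, right = (evaluate(a, env) for a in expr.args)
    return left + right if expr.op == "add" else left * right


def grad(expr, wrt):
    """Symbolic derivative of `expr` w.r.t. the node `wrt` (matched by identity)."""
    if expr is wrt:
        return Expr("const", 1.0)
    if expr.op in ("var", "const"):
        return Expr("const", 0.0)
    a, b = expr.args
    if expr.op == "add":
        return Expr("add", grad(a, wrt), grad(b, wrt))
    # Product rule for "mul".
    return Expr("add", Expr("mul", grad(a, wrt), b), Expr("mul", a, grad(b, wrt)))


def substitute(expr, old, new):
    """Rebuild `expr` with every occurrence of the node `old` replaced by `new`."""
    if expr is old:
        return new
    if expr.op in ("var", "const"):
        return expr
    return Expr(expr.op, *(substitute(a, old, new) for a in expr.args))


def potential_fn(q):
    # U(q) = 0.5 * q * q, an assumed stand-in for the negative log-density.
    return Expr("mul", Expr("const", 0.5), Expr("mul", q, q))


# new_position is a non-atomic graph: q + 0.1 * p.
q, p = Expr("var", "q"), Expr("var", "p")
new_position = Expr("add", q, Expr("mul", Expr("const", 0.1), p))

# 1. Differentiate with respect to an atomic placeholder.
pos = Expr("var", "pos")
energy_grad = grad(potential_fn(pos), pos)

# 2. Swap the placeholder for the real new_position graph.
new_energy_grad = substitute(energy_grad, pos, new_position)

# dU/dq evaluated at q = 1.0 + 0.1 * 2.0, i.e. approximately 1.2.
value = evaluate(new_energy_grad, {"q": 1.0, "p": 2.0})
print(value)
```

Because the placeholder has no inputs of its own, nothing upstream of it can be copied during graph construction, which is presumably why an atomic variable sidesteps the problem.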

rlouf commented 3 years ago

Ok, so if I understand correctly, the new_potential_energy on line 66 is not linked to the new_position defined on line 64 but to a clone? I tried changing the one_step function returned by velocity_verlet in integrators.py to the following:

    def one_step(
        position: TensorVariable,
        momentum: TensorVariable,
        potential_energy: TensorVariable,
        potential_energy_grad: TensorVariable,
        step_size: TensorVariable,
    ) -> IntegratorStateType:

        new_momentum = momentum - b1 * step_size * potential_energy_grad

        kinetic_grad = aesara.grad(kinetic_energy_fn(new_momentum), new_momentum)
        new_position = position + a2 * step_size * kinetic_grad

        pos = position.clone()
        npe = potential_fn(pos)
        npeg = aesara.grad(npe, pos)
        nm = new_momentum - b1 * step_size * npeg

        new_momentum = aesara.clone_replace(nm, {pos: new_position})
        new_potential_energy = aesara.clone_replace(npe, {pos: new_position})
        new_potential_energy_grad = aesara.clone_replace(npeg, {pos: new_position})

        return (
            new_position,
            new_momentum,
            new_potential_energy,
            new_potential_energy_grad,
        )

And it seems to work. Is that what you meant by your last sentence? If so, I can merge the fix temporarily so we can close this issue, but I think it would make more sense to prevent non-atomic random variables from being cloned in aeppl.

This fix is implemented in #26 along with the fix to handle scalar rvs.
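
For reference, the update that one_step builds symbolically can be written numerically in a few lines. This is a sketch under stated assumptions, not the aehmc implementation: it assumes the usual velocity Verlet coefficients b1 = 0.5 and a2 = 1 (the thread does not give their values), a unit mass matrix, a one-dimensional standard normal target, and a hypothetical function name `velocity_verlet_step`. The step size and number of steps match the values passed to hmc.kernel in the example above.

```python
def potential(q):
    """Potential energy of a standard normal, up to a constant: U(q) = q^2 / 2."""
    return 0.5 * q * q


def potential_grad(q):
    return q


def kinetic_grad(p):
    """Gradient of K(p) = p^2 / 2 (unit mass matrix, an assumption)."""
    return p


def velocity_verlet_step(q, p, grad_q, step_size, b1=0.5, a2=1.0):
    # Half step on the momentum, using the gradient at the current position.
    p_half = p - b1 * step_size * grad_q
    # Full step on the position.
    q_new = q + a2 * step_size * kinetic_grad(p_half)
    # The gradient must be taken at the new position; this is the step
    # that raised DisconnectedInputError in the symbolic version.
    grad_new = potential_grad(q_new)
    # Second half step on the momentum.
    p_new = p_half - b1 * step_size * grad_new
    return q_new, p_new, potential(q_new), grad_new


q, p = 1.0, 0.5
initial_energy = potential(q) + 0.5 * p * p

grad = potential_grad(q)
for _ in range(10):
    q, p, energy, grad = velocity_verlet_step(q, p, grad, 1e-3)

final_energy = energy + 0.5 * p * p
# The integrator should nearly conserve the total energy.
print(abs(final_energy - initial_energy))
```

Checking that the total energy barely changes over a short trajectory is a cheap sanity test for any rewrite of the integrator.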

brandonwillard commented 3 years ago

> Ok, so if I understand correctly, the new_potential_energy on line 66 is not linked to the new_position defined on line 64 but to a clone? I tried changing the one_step function returned by velocity_verlet in integrators.py to the following:

Yes, exactly.

brandonwillard commented 3 years ago

~I just put in a fix here: https://github.com/aesara-devs/aeppl/pull/63. I'll merge that as soon as the tests pass.~

Done.

rlouf commented 3 years ago

The fix in aeppl does the trick; I made a small change to support scalars (not just Python floats) as well. This should be OK now.

codecov[bot] commented 3 years ago

Codecov Report

Merging #25 (8c2f1af) into main (d321748) will not change coverage. The diff coverage is 100.00%.

@@            Coverage Diff            @@
##              main       #25   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            9         9           
  Lines          344       344           
  Branches        14        14           
=========================================
  Hits           344       344           

Impacted Files      Coverage Δ
aehmc/metrics.py    100.00% <100.00%> (ø)
