Closed rlouf closed 3 years ago
Tested with aeppl 0.0.9 and the development branch.
I documented the issue I have with scalar random variables. When I build the following model:
import aesara
import aesara.tensor as at
from aeppl import joint_logprob
from aesara.tensor.random.utils import RandomStream
from aehmc import hmc
Y_rv = at.random.normal(0, 1, size=(2,), name="Y")
def logprob_fn(y):
    logprob = joint_logprob(Y_rv, {Y_rv: y})
    return logprob
# Build the transition kernel
step_size = at.scalar("step_size")
inverse_mass_matrix = at.vector("inverse_mass_matrix")
num_integration_steps = at.scalar("num_integration_steps", dtype="int32")
srng = RandomStream(seed=1)
kernel = hmc.kernel(srng, logprob_fn, 1e-3, at.ones(2), 10)
# Compile a function that updates the chain
Y_tt = Y_rv.clone()
potential_energy = -logprob_fn(Y_tt)
potential_energy_grad = aesara.grad(potential_energy, wrt=Y_tt)
next_step = kernel(Y_tt, potential_energy, potential_energy_grad)
I get the following error:
Traceback (most recent call last):
File "/home/remi/projects/aehmc/example.py", line 29, in <module>
next_step = kernel(Y_tt, potential_energy, potential_energy_grad)
File "/home/remi/projects/aehmc/aehmc/hmc.py", line 77, in step
) = proposal_generator(srng, q, p, potential_energy, potential_energy_grad)
File "/home/remi/projects/aehmc/aehmc/hmc.py", line 105, in propose
new_q, new_p, new_potential_energy, new_potential_energy_grad = integrate(
File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 48, in integrate
[q, p, energy, energy_grad], _ = aesara.scan(
File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/scan/basic.py", line 751, in scan
condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 43, in one_step
new_state = integrator(
File "/home/remi/projects/aehmc/aehmc/integrators.py", line 67, in one_step
new_potential_energy_grad = aesara.grad(new_potential_energy, new_position)
File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/gradient.py", line 615, in grad
handle_disconnected(elem)
File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/gradient.py", line 601, in handle_disconnected
raise DisconnectedInputError(message)
aesara.gradient.DisconnectedInputError:
Backtrace when that variable is created:
File "/home/remi/projects/aehmc/example.py", line 29, in <module>
next_step = kernel(Y_tt, potential_energy, potential_energy_grad)
File "/home/remi/projects/aehmc/aehmc/hmc.py", line 77, in step
) = proposal_generator(srng, q, p, potential_energy, potential_energy_grad)
File "/home/remi/projects/aehmc/aehmc/hmc.py", line 105, in propose
new_q, new_p, new_potential_energy, new_potential_energy_grad = integrate(
File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 48, in integrate
[q, p, energy, energy_grad], _ = aesara.scan(
File "/home/remi/.virtualenvs/aehmc-nuts/lib/python3.9/site-packages/aesara/scan/basic.py", line 751, in scan
condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
File "/home/remi/projects/aehmc/aehmc/trajectory.py", line 43, in one_step
new_state = integrator(
File "/home/remi/projects/aehmc/aehmc/integrators.py", line 64, in one_step
new_position = position + a2 * step_size * kinetic_grad
It looks like it may be specific to having only one random variable in the model. Indeed, the following compiles:
Y_rv = at.random.normal(0, 1, size=(2,), name="Y")
Z_rv = at.random.normal(Y_rv, 1)
def logprob_fn(y):
    logprob = joint_logprob(Z_rv, {Y_rv: y, Z_rv: at.as_tensor([.1, .1])})
    return logprob
The scalar issue arises in the one_step function created in velocity_verlet when aesara.grad(new_potential_energy, new_position) is evaluated.
We print the two graphs, new_potential_energy and new_position, below:
aesara.dprint([new_potential_energy, new_position])
# Elemwise{neg,no_inplace} [id A] ''
# |Sum{acc_dtype=float64} [id B] ''
# |Assert{msg='sigma > 0'} [id C] ''
# |Elemwise{sub,no_inplace} [id D] ''
# | |Elemwise{sub,no_inplace} [id E] ''
# | | |Elemwise{mul,no_inplace} [id F] ''
# | | | |InplaceDimShuffle{x} [id G] ''
# | | | | |TensorConstant{-0.5} [id H]
# | | | |Elemwise{pow,no_inplace} [id I] ''
# | | | |Elemwise{true_div,no_inplace} [id J] ''
# | | | | |Elemwise{sub,no_inplace} [id K] ''
# | | | | | |Elemwise{add,no_inplace} [id L] ''
# | | | | | | |y[t-1] [id M]
# | | | | | | |Elemwise{mul,no_inplace} [id N] ''
# | | | | | | |InplaceDimShuffle{x} [id O] ''
# | | | | | | | |TensorConstant{0.001} [id P]
# | | | | | | |Elemwise{add,no_inplace} [id Q] ''
# | | | | | | |Elemwise{mul} [id R] ''
# | | | | | | | |Elemwise{mul,no_inplace} [id S] ''
# | | | | | | | | |InplaceDimShuffle{x} [id T] ''
# | | | | | | | | | |Elemwise{mul} [id U] ''
# | | | | | | | | | |Elemwise{second,no_inplace} [id V] ''
# | | | | | | | | | | |Elemwise{mul,no_inplace} [id W] ''
# | | | | | | | | | | | |TensorConstant{0.5} [id X]
# | | | | | | | | | | | |dot [id Y] ''
# | | | | | | | | | | | |Elemwise{mul,no_inplace} [id Z] ''
# | | | | | | | | | | | | |Alloc [id BA] ''
# | | | | | | | | | | | | | |TensorConstant{1.0} [id BB]
# | | | | | | | | | | | | | |TensorConstant{2} [id BC]
# | | | | | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | | | | | | |<TensorType(float64, vector)> [id BE]
# | | | | | | | | | | | | |Elemwise{mul,no_inplace} [id BF] ''
# | | | | | | | | | | | | |InplaceDimShuffle{x} [id BG] ''
# | | | | | | | | | | | | | |TensorConstant{0.0005} [id BH]
# | | | | | | | | | | | | |<TensorType(float64, vector)> [id BI]
# | | | | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | | | | |TensorConstant{1.0} [id BJ]
# | | | | | | | | | |TensorConstant{0.5} [id X]
# | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | |Alloc [id BA] ''
# | | | | | | |Elemwise{mul,no_inplace} [id BK] ''
# | | | | | | |InplaceDimShuffle{x} [id BL] ''
# | | | | | | | |Elemwise{mul} [id U] ''
# | | | | | | |Elemwise{mul,no_inplace} [id Z] ''
# | | | | | |InplaceDimShuffle{x} [id BM] ''
# | | | | | |TensorConstant{0} [id BN]
# | | | | |InplaceDimShuffle{x} [id BO] ''
# | | | | |TensorConstant{1} [id BP]
# | | | |InplaceDimShuffle{x} [id BQ] ''
# | | | |TensorConstant{2} [id BR]
# | | |InplaceDimShuffle{x} [id BS] ''
# | | |Elemwise{log,no_inplace} [id BT] ''
# | | |TensorConstant{2.5066282746310002} [id BU]
# | |InplaceDimShuffle{x} [id BV] ''
# | |Elemwise{log,no_inplace} [id BW] ''
# | |TensorConstant{1} [id BP]
# |All [id BX] ''
# |Elemwise{gt,no_inplace} [id BY] ''
# |TensorConstant{1} [id BP]
# |TensorConstant{0.0} [id BZ]
# Elemwise{add,no_inplace} [id CA] ''
# |y[t-1] [id M]
# |Elemwise{mul,no_inplace} [id CB] ''
# |InplaceDimShuffle{x} [id CC] ''
# | |TensorConstant{0.001} [id P]
# |Elemwise{add,no_inplace} [id CD] ''
# |Elemwise{mul} [id CE] ''
# | |Elemwise{mul,no_inplace} [id CF] ''
# | | |InplaceDimShuffle{x} [id CG] ''
# | | | |Elemwise{mul} [id CH] ''
# | | | |Elemwise{second,no_inplace} [id CI] ''
# | | | | |Elemwise{mul,no_inplace} [id CJ] ''
# | | | | | |TensorConstant{0.5} [id X]
# | | | | | |dot [id CK] ''
# | | | | | |Elemwise{mul,no_inplace} [id CL] ''
# | | | | | | |Alloc [id CM] ''
# | | | | | | | |TensorConstant{1.0} [id BB]
# | | | | | | | |TensorConstant{2} [id BC]
# | | | | | | |Elemwise{sub,no_inplace} [id CN] ''
# | | | | | | |<TensorType(float64, vector)> [id BE]
# | | | | | | |Elemwise{mul,no_inplace} [id CO] ''
# | | | | | | |InplaceDimShuffle{x} [id CP] ''
# | | | | | | | |TensorConstant{0.0005} [id BH]
# | | | | | | |<TensorType(float64, vector)> [id BI]
# | | | | | |Elemwise{sub,no_inplace} [id CN] ''
# | | | | |TensorConstant{1.0} [id BJ]
# | | | |TensorConstant{0.5} [id X]
# | | |Elemwise{sub,no_inplace} [id CN] ''
# | |Alloc [id CM] ''
# |Elemwise{mul,no_inplace} [id CQ] ''
# |InplaceDimShuffle{x} [id CR] ''
# | |Elemwise{mul} [id CH] ''
# |Elemwise{mul,no_inplace} [id CL] ''
The graph with the ID CA (i.e. new_position) is the variable with respect to which the gradient of new_potential_energy is computed, and a quick search for that ID shows that it does not appear in the graph of new_potential_energy.
Since new_potential_energy = potential_fn(new_position), we need to determine why potential_fn is not returning a graph that's a function of its input new_position.
If we set new_position.name = "new_position" and reproduce the graph print-out, we get the following:
new_position.name = "new_position"
aesara.dprint([potential_fn(new_position), new_position])
# Elemwise{neg,no_inplace} [id A] ''
# |Sum{acc_dtype=float64} [id B] ''
# |Assert{msg='sigma > 0'} [id C] 'new_position_logprob'
# |Elemwise{sub,no_inplace} [id D] ''
# | |Elemwise{sub,no_inplace} [id E] ''
# | | |Elemwise{mul,no_inplace} [id F] ''
# | | | |InplaceDimShuffle{x} [id G] ''
# | | | | |TensorConstant{-0.5} [id H]
# | | | |Elemwise{pow,no_inplace} [id I] ''
# | | | |Elemwise{true_div,no_inplace} [id J] ''
# | | | | |Elemwise{sub,no_inplace} [id K] ''
# | | | | | |Elemwise{add,no_inplace} [id L] 'new_position'
# | | | | | | |y[t-1] [id M]
# | | | | | | |Elemwise{mul,no_inplace} [id N] ''
# | | | | | | |InplaceDimShuffle{x} [id O] ''
# | | | | | | | |TensorConstant{0.001} [id P]
# | | | | | | |Elemwise{add,no_inplace} [id Q] ''
# | | | | | | |Elemwise{mul} [id R] ''
# | | | | | | | |Elemwise{mul,no_inplace} [id S] ''
# | | | | | | | | |InplaceDimShuffle{x} [id T] ''
# | | | | | | | | | |Elemwise{mul} [id U] ''
# | | | | | | | | | |Elemwise{second,no_inplace} [id V] ''
# | | | | | | | | | | |Elemwise{mul,no_inplace} [id W] ''
# | | | | | | | | | | | |TensorConstant{0.5} [id X]
# | | | | | | | | | | | |dot [id Y] ''
# | | | | | | | | | | | |Elemwise{mul,no_inplace} [id Z] ''
# | | | | | | | | | | | | |Alloc [id BA] ''
# | | | | | | | | | | | | | |TensorConstant{1.0} [id BB]
# | | | | | | | | | | | | | |TensorConstant{2} [id BC]
# | | | | | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | | | | | | |<TensorType(float64, vector)> [id BE]
# | | | | | | | | | | | | |Elemwise{mul,no_inplace} [id BF] ''
# | | | | | | | | | | | | |InplaceDimShuffle{x} [id BG] ''
# | | | | | | | | | | | | | |TensorConstant{0.0005} [id BH]
# | | | | | | | | | | | | |<TensorType(float64, vector)> [id BI]
# | | | | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | | | | |TensorConstant{1.0} [id BJ]
# | | | | | | | | | |TensorConstant{0.5} [id X]
# | | | | | | | | |Elemwise{sub,no_inplace} [id BD] ''
# | | | | | | | |Alloc [id BA] ''
# | | | | | | |Elemwise{mul,no_inplace} [id BK] ''
# | | | | | | |InplaceDimShuffle{x} [id BL] ''
# | | | | | | | |Elemwise{mul} [id U] ''
# | | | | | | |Elemwise{mul,no_inplace} [id Z] ''
# | | | | | |InplaceDimShuffle{x} [id BM] ''
# | | | | | |TensorConstant{0} [id BN]
# | | | | |InplaceDimShuffle{x} [id BO] ''
# | | | | |TensorConstant{1} [id BP]
# | | | |InplaceDimShuffle{x} [id BQ] ''
# | | | |TensorConstant{2} [id BR]
# | | |InplaceDimShuffle{x} [id BS] ''
# | | |Elemwise{log,no_inplace} [id BT] ''
# | | |TensorConstant{2.5066282746310002} [id BU]
# | |InplaceDimShuffle{x} [id BV] ''
# | |Elemwise{log,no_inplace} [id BW] ''
# | |TensorConstant{1} [id BP]
# |All [id BX] ''
# |Elemwise{gt,no_inplace} [id BY] ''
# |TensorConstant{1} [id BP]
# |TensorConstant{0.0} [id BZ]
# Elemwise{add,no_inplace} [id CA] 'new_position'
# |y[t-1] [id M]
# |Elemwise{mul,no_inplace} [id CB] ''
# |InplaceDimShuffle{x} [id CC] ''
# | |TensorConstant{0.001} [id P]
# |Elemwise{add,no_inplace} [id CD] ''
# |Elemwise{mul} [id CE] ''
# | |Elemwise{mul,no_inplace} [id CF] ''
# | | |InplaceDimShuffle{x} [id CG] ''
# | | | |Elemwise{mul} [id CH] ''
# | | | |Elemwise{second,no_inplace} [id CI] ''
# | | | | |Elemwise{mul,no_inplace} [id CJ] ''
# | | | | | |TensorConstant{0.5} [id X]
# | | | | | |dot [id CK] ''
# | | | | | |Elemwise{mul,no_inplace} [id CL] ''
# | | | | | | |Alloc [id CM] ''
# | | | | | | | |TensorConstant{1.0} [id BB]
# | | | | | | | |TensorConstant{2} [id BC]
# | | | | | | |Elemwise{sub,no_inplace} [id CN] ''
# | | | | | | |<TensorType(float64, vector)> [id BE]
# | | | | | | |Elemwise{mul,no_inplace} [id CO] ''
# | | | | | | |InplaceDimShuffle{x} [id CP] ''
# | | | | | | | |TensorConstant{0.0005} [id BH]
# | | | | | | |<TensorType(float64, vector)> [id BI]
# | | | | | |Elemwise{sub,no_inplace} [id CN] ''
# | | | | |TensorConstant{1.0} [id BJ]
# | | | |TensorConstant{0.5} [id X]
# | | |Elemwise{sub,no_inplace} [id CN] ''
# | |Alloc [id CM] ''
# |Elemwise{mul,no_inplace} [id CQ] ''
# |InplaceDimShuffle{x} [id CR] ''
# | |Elemwise{mul} [id CH] ''
# |Elemwise{mul,no_inplace} [id CL] ''
This output makes it clearer that new_position is in the graph output by potential_fn(new_position), but under a different ID (L instead of CA). The implication is that it has been cloned, so it is no longer identical to the original new_position. This seems to be the underlying cause of the error.
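To illustrate why a clone is enough to trigger a disconnected-input error, here is a toy sketch in plain Python. It is not Aesara's implementation: the Node class and the helpers ancestors, is_connected, plain_potential_fn and cloning_potential_fn are hypothetical names invented for this example. It only assumes that connectivity is decided by object identity rather than by structural equality.

```python
class Node:
    """A toy graph node: an op name plus its input nodes (hypothetical, not Aesara)."""

    def __init__(self, op, *inputs, name=None):
        self.op = op
        self.inputs = inputs
        self.name = name

    def clone(self):
        # Same op, same inputs, same name, but a distinct Python object.
        return Node(self.op, *self.inputs, name=self.name)


def ancestors(node):
    """Yield `node` and every node reachable through its inputs."""
    stack = [node]
    seen = set()
    while stack:
        current = stack.pop()
        if id(current) in seen:
            continue
        seen.add(id(current))
        yield current
        stack.extend(current.inputs)


def is_connected(cost, wrt):
    """Return True if `wrt` itself (by identity) appears in the graph of `cost`."""
    return any(n is wrt for n in ancestors(cost))


def plain_potential_fn(position):
    # Uses the argument as-is.
    return Node("neg", Node("logprob", position))


def cloning_potential_fn(position):
    # Clones its non-atomic argument before using it.
    return Node("neg", Node("logprob", position.clone()))


y = Node("input", name="y")
step = Node("input", name="step")
new_position = Node("add", y, step, name="new_position")

connected_plain = is_connected(plain_potential_fn(new_position), new_position)
connected_cloned = is_connected(cloning_potential_fn(new_position), new_position)
# The atomic input `y` is shared by the clone, so it stays reachable.
connected_input = is_connected(cloning_potential_fn(new_position), y)
```

In this toy, the cloned graph still reaches the shared input y, just as y[t-1] keeps the ID M in both print-outs, while the add node itself is a different object.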
We can look into this more from the AePPL side, but we can also perform the aesara.grad step on a new, "atomic" variable and then replace that variable with the original new_position graph in new_momentum.
Ok, so if I understand correctly, the new_potential_energy on line 66 is not linked to the new_position defined on line 64 but to a clone? I tried changing the one_step function returned by velocity_verlet in integrators.py to the following:
def one_step(
    position: TensorVariable,
    momentum: TensorVariable,
    potential_energy: TensorVariable,
    potential_energy_grad: TensorVariable,
    step_size: TensorVariable,
) -> IntegratorStateType:
    new_momentum = momentum - b1 * step_size * potential_energy_grad
    kinetic_grad = aesara.grad(kinetic_energy_fn(new_momentum), new_momentum)
    new_position = position + a2 * step_size * kinetic_grad

    # Differentiate with respect to a fresh stand-in variable for the position...
    pos = position.clone()
    npe = potential_fn(pos)
    npeg = aesara.grad(npe, pos)
    nm = new_momentum - b1 * step_size * npeg

    # ...then substitute the actual new_position graph for the stand-in.
    new_momentum = aesara.clone_replace(nm, {pos: new_position})
    new_potential_energy = aesara.clone_replace(npe, {pos: new_position})
    new_potential_energy_grad = aesara.clone_replace(npeg, {pos: new_position})

    return (
        new_position,
        new_momentum,
        new_potential_energy,
        new_potential_energy_grad,
    )
And it seems to work. Is that what you meant with your last sentence? If so, I can merge the fix temporarily so we can close this issue, but I think it would make more sense to prevent non-atomic random variables from being cloned in aeppl.
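For readers less familiar with the integrator being patched, the update order in one_step (half momentum step, full position step, half momentum step) can be sketched in plain Python with analytic gradients. This is only an illustration, not the aehmc code: it assumes a standard normal target, an identity inverse mass matrix, and the usual velocity Verlet coefficients a2 = 1 and b1 = 0.5, none of which are stated in this thread. All function names are hypothetical.

```python
def potential_energy(q):
    """Negative log-density of a standard normal, up to an additive constant."""
    return 0.5 * sum(x * x for x in q)


def potential_energy_grad(q):
    """Analytic gradient of potential_energy with respect to q."""
    return list(q)


def kinetic_energy(p):
    """Kinetic energy under an identity inverse mass matrix (an assumption)."""
    return 0.5 * sum(x * x for x in p)


def velocity_verlet_step(q, p, grad, step_size, a2=1.0, b1=0.5):
    """Half momentum update, full position update, half momentum update."""
    p_half = [pi - b1 * step_size * gi for pi, gi in zip(p, grad)]
    # With an identity mass matrix the kinetic-energy gradient is the momentum.
    new_q = [qi + a2 * step_size * pi for qi, pi in zip(q, p_half)]
    # The gradient has to be taken at the *new* position; this is the step
    # that raised DisconnectedInputError in the symbolic version.
    new_grad = potential_energy_grad(new_q)
    new_p = [pi - b1 * step_size * gi for pi, gi in zip(p_half, new_grad)]
    return new_q, new_p, potential_energy(new_q), new_grad


q, p = [1.0, -0.5], [0.3, 0.2]
grad = potential_energy_grad(q)
initial_energy = potential_energy(q) + kinetic_energy(p)

# 10 integration steps of size 1e-3, matching the values passed to hmc.kernel.
for _ in range(10):
    q, p, energy, grad = velocity_verlet_step(q, p, grad, step_size=1e-3)

final_energy = energy + kinetic_energy(p)
```

With a step size this small the total energy should stay nearly constant, which is a quick sanity check that the three updates are applied in the right order.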
This fix is implemented in #26 along with the fix to handle scalar rvs.
> Ok, so if I understand correctly, the new_potential_energy on line 66 is not linked to the new_position defined on line 64 but to a clone? I tried changing the one_step function returned by velocity_verlet in integrators.py to the following:
Yes, exactly.
~I just put in a fix here: https://github.com/aesara-devs/aeppl/pull/63. I'll merge that as soon as the tests pass.~
Done.
The fix in aeppl does the trick; I made a small change to support scalars (not Python floats) as well. This should be ok now.
Merging #25 (8c2f1af) into main (d321748) will not change coverage. The diff coverage is 100.00%.
@@ Coverage Diff @@
## main #25 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 9 9
Lines 344 344
Branches 14 14
=========================================
Hits 344 344
Impacted Files | Coverage Δ | |
---|---|---|
aehmc/metrics.py | 100.00% <100.00%> (ø) |
Closes #16