Closed rlouf closed 2 years ago
Merging #76 (29418b6) into main (e6498d4) will not change coverage. The diff coverage is 100.00%.

```
@@           Coverage Diff           @@
##              main       #76   +/-  ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files           13        13
  Lines          541       545     +4
  Branches        31        32     +1
=========================================
+ Hits           541       545     +4
```

Impacted Files | Coverage Δ |
---|---|
aehmc/proposals.py | 100.00% <100.00%> (ø) |
aehmc/trajectory.py | 100.00% <100.00%> (ø) |
Continue to review the full report at Codecov.
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.
Powered by Codecov. Last update e6498d4...29418b6.
The initial state is now indeed returned when the first transition is divergent (see the modification we needed to make in test_trajectory.py), but the following code still fails when executed:
```python
import aesara
import aesara.tensor as at
from aeppl import joint_logprob
from aehmc import nuts

srng = at.random.RandomStream(seed=0)
Y_rv = srng.normal(1, 2)

def logprob_fn(y):
    # scale the logprob so that moderate inputs already overflow
    logprob = 1e20 * joint_logprob({Y_rv: y})
    return logprob

y_vv = Y_rv.clone()
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

params = (at.scalar(), at.scalar())  # step size, inverse mass matrix
new_state, updates = kernel(*initial_state, *params)
nuts_step_fn = aesara.function(
    (y_vv, *params),
    new_state,
    updates=updates,
    mode=aesara.compile.mode.Mode(linker="cvm"),
)

step_size = 1e40
inverse_mass_matrix = 1.0
new_state = nuts_step_fn(1.0, step_size, inverse_mass_matrix)
```
with the same error described in #76. I suspect that the compiled function evaluates both `IfElse` branches.
It looks like my suspicion is correct; the following code fails as well:
```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

srng = at.random.RandomStream(0)

true_res = at.scalar(dtype=np.int64)
a = at.as_tensor(1)
b = at.scalar()
p = a / b  # evaluates to inf when b == 0
false_res = srng.bernoulli(p)

cond = np.array(True)
result = ifelse(cond, true_res, false_res)

result_fn = aesara.function((true_res, b), result)
result_fn(1, 0)  # fails although only the true branch should be taken
```
with the following trace:
```
Traceback (most recent call last):
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 975, in __call__
    self.vm()
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/op.py", line 368, in perform
    smpl_val = self.rng_fn(rng, *(args + [size]))
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 55, in rng_fn
    res = cls.rng_fn_scipy(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 529, in rng_fn_scipy
    return stats.bernoulli.rvs(p, size=size, random_state=rng)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 3151, in rvs
    return super().rvs(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 1068, in rvs
    raise ValueError("Domain error in arguments.")
ValueError: Domain error in arguments.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 20, in <module>
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 988, in __call__
    raise_with_op(
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/link/utils.py", line 534, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 975, in __call__
    self.vm()
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/op.py", line 368, in perform
    smpl_val = self.rng_fn(rng, *(args + [size]))
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 55, in rng_fn
    res = cls.rng_fn_scipy(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 529, in rng_fn_scipy
    return stats.bernoulli.rvs(p, size=size, random_state=rng)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 3151, in rvs
    return super().rvs(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 1068, in rvs
    raise ValueError("Domain error in arguments.")
ValueError: Domain error in arguments.
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F74A77FB5A0>), TensorConstant{[]}, TensorConstant{4}, Elemwise{reciprocal,no_inplace}.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ())]
Inputs shapes: ['No shapes', (0,), (), ()]
Inputs strides: ['No strides', (0,), (), ()]
Inputs values: [Generator(PCG64) at 0x7F74A77FB5A0, array([], dtype=int64), array(4), array(inf)]
Outputs clients: [['output'], [if{inplace}(TensorConstant{True}, <TensorType(int64, ())>, bernoulli_rv{0, (0,), int64, True}.out)]]
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```
which means that the graph corresponding to the `else` branch is evaluated even though the condition is always `True`. It fails even when `optimizer` is set to `None`.
Since both branches are evaluated, I added a (temporary, hopefully) check that sets the value of `p_accept` to `0` when it is `NaN`, so the fix can be merged quickly. I also added a new test case.
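In pure Python the guard amounts to something like the sketch below. This is a hypothetical illustration of the check described above, not the actual aehmc code; the name `guard_p_accept` is made up for the example:

```python
import math

def guard_p_accept(p_accept):
    # Temporary guard: a NaN acceptance probability (e.g. produced by
    # an inf - inf energy difference) is replaced by 0, so the
    # proposal is always rejected instead of crashing the sampler.
    return 0.0 if math.isnan(p_accept) else p_accept

print(guard_p_accept(float("nan")))  # 0.0
print(guard_p_accept(0.7))           # 0.7
```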
The NUTS kernel fails for a combination of a large step size and large logprob values (#75). This happens because AeHMC attempts to build the trajectory even if the first step is divergent, and ends up subtracting two `np.inf` values. In this PR we skip trajectory building when the first transition is divergent.

Closes #75