aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License

Skip trajectory building if the first step diverges #76

Closed: rlouf closed this 2 years ago

rlouf commented 2 years ago

The NUTS kernel fails when a large step size is combined with large logprob values (#75). This happens because AeHMC attempts to build the trajectory even when the first step diverges, and then subtracts two np.inf values, which yields NaN. In this PR we skip trajectory building if the first transition is divergent.

Closes #75
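A minimal NumPy sketch of the failure mode and of the fix, assuming the usual Metropolis rule p_accept = min(1, exp(H_init - H_new)), where H is the energy (Hamiltonian) of a state. The helpers acceptance_probability and first_step_or_skip are hypothetical and are not part of AeHMC's API. When both energies overflow to np.inf, their difference is NaN and the NaN propagates into the acceptance probability, unless the divergent first step short-circuits trajectory building.

```python
import numpy as np


def acceptance_probability(energy_init, energy_new):
    """Simplified Metropolis acceptance probability (hypothetical helper)."""
    # inf - inf evaluates to nan, and nan propagates through exp and minimum
    with np.errstate(invalid="ignore", over="ignore"):
        delta = energy_init - energy_new
        return np.minimum(1.0, np.exp(delta))


def first_step_or_skip(initial_state, first_state, diverged):
    """Return the initial state when the first step diverges (hypothetical helper).

    Otherwise return the first state, from which the trajectory would be built.
    """
    if diverged:
        return initial_state
    return first_state


# A large step size and large logprob values make both energies overflow.
print(acceptance_probability(np.inf, np.inf))  # nan

# Skipping trajectory building avoids ever computing that nan.
print(first_step_or_skip("initial", "first step", diverged=True))  # initial
```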

codecov[bot] commented 2 years ago

Codecov Report

Merging #76 (29418b6) into main (e6498d4) will not change coverage. The diff coverage is 100.00%.

@@            Coverage Diff            @@
##              main       #76   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           13        13           
  Lines          541       545    +4     
  Branches        31        32    +1     
=========================================
+ Hits           541       545    +4     
Impacted Files          Coverage Δ
aehmc/proposals.py      100.00% <100.00%> (ø)
aehmc/trajectory.py     100.00% <100.00%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e6498d4...29418b6.

rlouf commented 2 years ago

The initial state is now indeed returned when the first transition is divergent (see the modification we needed to make in test_trajectory.py), but the following code still fails when executed

import numpy as np

import aesara
import aesara.tensor as at
from aesara.tensor.random import RandomStream
from aeppl import joint_logprob
from aehmc import nuts

srng = RandomStream(seed=0)
Y_rv = srng.normal(1, 2)

def logprob_fn(y):
    # Scaling by 1e20 yields logprob values of very large magnitude
    logprob = 1e20 * joint_logprob({Y_rv: y})
    return logprob

y_vv = Y_rv.clone()
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

params = (at.scalar(), at.scalar())
new_state, updates = kernel(*initial_state, *params)
nuts_step_fn = aesara.function(
    (y_vv, *params),
    new_state,
    updates=updates,
    mode=aesara.compile.mode.Mode(linker='cvm')
)

# A huge step size makes the very first integration step diverge
step_size = 1e40
inverse_mass_matrix = 1.
new_state = nuts_step_fn(1., step_size, inverse_mass_matrix)

with the same error described in #76. I suspect that the compiled function evaluates both IfElse branches.
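To make the suspicion concrete, here is a plain-Python analogy (not Aesara code; all function names are hypothetical). An eager select computes both branch values before choosing, so a failing else branch raises even when the condition is True, whereas a lazy conditional only computes the branch it needs. The Aesara (formerly Theano) documentation describes ifelse as lazy in this sense, in contrast to switch, which evaluates both inputs.

```python
def eager_select(cond, true_value, false_value):
    """Both values were already computed by the caller before this runs."""
    return true_value if cond else false_value


def lazy_select(cond, true_thunk, false_thunk):
    """Only the thunk selected by `cond` is ever called."""
    return true_thunk() if cond else false_thunk()


def failing_branch():
    # Stands in for the else branch that raises (e.g. a Bernoulli draw with p = inf)
    raise ValueError("Domain error in arguments.")


# The lazy version never touches the failing branch.
print(lazy_select(True, lambda: 1, failing_branch))  # 1

# The eager version fails because its arguments are evaluated first.
try:
    eager_select(True, 1, failing_branch())
except ValueError as err:
    print("eager select failed:", err)
```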

rlouf commented 2 years ago

It looks like my suspicion is correct; the following code fails as well:

import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

import numpy as np

srng = at.random.RandomStream(0)

true_res = at.scalar(dtype=np.int64)

a = at.as_tensor(1)
b = at.scalar()
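p = a / b
# With b = 0, p is inf, which is outside the Bernoulli domain [0, 1]
false_res = srng.bernoulli(p)

# The condition is constant and True, so the else branch should never be needed
cond = np.array(True)
result = ifelse(cond, true_res, false_res)

result_fn = aesara.function((true_res, b), result)
result_fn(1, 0)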

with the following trace:

Traceback (most recent call last):
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 975, in __call__
    self.vm()
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/op.py", line 368, in perform
    smpl_val = self.rng_fn(rng, *(args + [size]))
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 55, in rng_fn
    res = cls.rng_fn_scipy(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 529, in rng_fn_scipy
    return stats.bernoulli.rvs(p, size=size, random_state=rng)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 3151, in rvs
    return super().rvs(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 1068, in rvs
    raise ValueError("Domain error in arguments.")
ValueError: Domain error in arguments.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 20, in <module>
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 988, in __call__
    raise_with_op(
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/link/utils.py", line 534, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 975, in __call__
    self.vm()
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/op.py", line 368, in perform
    smpl_val = self.rng_fn(rng, *(args + [size]))
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 55, in rng_fn
    res = cls.rng_fn_scipy(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/aesara/tensor/random/basic.py", line 529, in rng_fn_scipy
    return stats.bernoulli.rvs(p, size=size, random_state=rng)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 3151, in rvs
    return super().rvs(*args, **kwargs)
  File "/home/remi/.conda/envs/aehmc-dev/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py", line 1068, in rvs
    raise ValueError("Domain error in arguments.")
ValueError: Domain error in arguments.
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F74A77FB5A0>), TensorConstant{[]}, TensorConstant{4}, Elemwise{reciprocal,no_inplace}.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ())]
Inputs shapes: ['No shapes', (0,), (), ()]
Inputs strides: ['No strides', (0,), (), ()]
Inputs values: [Generator(PCG64) at 0x7F74A77FB5A0, array([], dtype=int64), array(4), array(inf)]
Outputs clients: [['output'], [if{inplace}(TensorConstant{True}, <TensorType(int64, ())>, bernoulli_rv{0, (0,), int64, True}.out)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

which means that the graph of the else branch is evaluated even though the condition is always True. The failure persists even when the Aesara flag optimizer is set to None.

rlouf commented 2 years ago

Since both branches are evaluated, I added a (hopefully temporary) check that sets p_accept to 0 when it is NaN, so that the fix can be merged quickly. I also added a new test case.
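A NumPy sketch of the kind of guard described above, assuming p_accept is an array of floats; guard_nan_acceptance is a hypothetical name, and this is not the code from the PR. In Aesara the analogous expression would be at.switch(at.isnan(p_accept), 0.0, p_accept), though whether the PR uses exactly this form is not shown here.

```python
import numpy as np


def guard_nan_acceptance(p_accept):
    """Replace NaN acceptance probabilities with 0 (hypothetical helper).

    A NaN p_accept (e.g. from inf - inf in the energy difference) is treated
    as a certain rejection, so the sampler keeps its current state.
    """
    p_accept = np.asarray(p_accept, dtype=float)
    return np.where(np.isnan(p_accept), 0.0, p_accept)


print(guard_nan_acceptance([0.3, np.nan, 1.0]).tolist())  # [0.3, 0.0, 1.0]
```

Mapping NaN to 0 rather than 1 means a proposal whose acceptance probability could not be computed is never accepted, which matches returning the initial state when the first step diverges.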