aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License

NUTS kernel fails with large log-probability values and step size #75

Closed: rlouf closed this issue 2 years ago

rlouf commented 2 years ago
import aesara
import aesara.tensor as at
from aesara.tensor.random import RandomStream
from aeppl import joint_logprob
from aehmc import nuts

srng = RandomStream(seed=0)
Y_rv = srng.normal(1, 2)

def logprob_fn(y):
    # Scale the log-density by 1e20 to produce very large log-probability values
    logprob = 1e20 * joint_logprob({Y_rv: y})
    return logprob

y_vv = Y_rv.clone()  # value variable for Y_rv
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

params = (at.scalar(), at.scalar())  # step_size, inverse_mass_matrix
new_state, updates = kernel(*initial_state, *params)
nuts_step_fn = aesara.function(
    (y_vv, *params), new_state, updates=updates
)

step_size = 1.
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# [array(100.), array(1.22673709e+23), array(2.475e+21), array(0.), array(1), array(True), array(True)]

step_size = 1e40
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# Exception due to `p_accept` evaluating to `NaN`
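To see why a step size of `1e40` breaks things, the arithmetic of a single leapfrog step can be replayed in plain NumPy. This is a standalone sketch, not aehmc's integrator: it assumes a standard velocity Verlet update, a unit inverse mass matrix, and a zero initial momentum (the real kernel draws the momentum at random, but any moderate value overflows the same way). `potential`, `grad_potential`, `leapfrog` and `total_energy` are hypothetical helpers.

```python
import numpy as np

# Standalone replay of one leapfrog step for the target in this issue:
# a Normal(1, 2) log-density scaled by 1e20. NOT aehmc's implementation.
SCALE = 1e20
MU, SIGMA = 1.0, 2.0


def potential(q):
    # Potential energy U(q) = -1e20 * log N(q | 1, 2)
    return SCALE * (
        0.5 * np.log(2 * np.pi) + np.log(SIGMA) + (q - MU) ** 2 / (2 * SIGMA**2)
    )


def grad_potential(q):
    return SCALE * (q - MU) / SIGMA**2


def leapfrog(q, p, step_size, inverse_mass=1.0):
    # Velocity Verlet: half momentum step, full position step, half momentum step
    p = p - 0.5 * step_size * grad_potential(q)
    q = q + step_size * inverse_mass * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p


def total_energy(q, p, inverse_mass=1.0):
    # Hamiltonian = potential + kinetic energy
    return potential(q) + 0.5 * inverse_mass * p**2


q0, p0 = np.float64(100.0), np.float64(0.0)  # assumed zero initial momentum
with np.errstate(over="ignore"):  # silence the float64 overflow warning
    h0 = total_energy(q0, p0)
    q1, p1 = leapfrog(q0, p0, step_size=1e40)
    h1 = total_energy(q1, p1)

print(h0)  # finite, about 1.2267e23
print(h1)  # inf: 0.5 * p1**2 overflows float64
```

With zero momentum, `h0` equals the potential energy, which matches the `1.22673709e+23` in the first output above. After one step the momentum is around `1.5e160`, so the kinetic energy `0.5 * p**2` exceeds the largest float64 (about `1.8e308`) and becomes `inf`. Any later expression that subtracts one such infinite energy from another yields `NaN`.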

`p_accept` should never evaluate to `NaN` in this case. Instead, we expect the kernel to return the initial state with `is_diverging=True`:

step_size = 1e40
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# Expected: [array(100.), array(1.22673709e+23), array(2.475e+21), array(0.), array(1), array(True), array(True)]
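The expected behaviour amounts to a guard that runs before the acceptance probability is computed. Below is a minimal pure-Python sketch of that logic. It assumes the common NUTS convention that a step diverges when the energy rises by more than a threshold (1000 here, an assumed value, not necessarily aehmc's) or when the new energy is not finite. `is_divergent` and `first_step` are hypothetical helpers, not aehmc functions.

```python
import math

# Assumed divergence threshold on the energy increase (not taken from aehmc)
DIVERGENCE_THRESHOLD = 1000.0


def is_divergent(initial_energy, new_energy, threshold=DIVERGENCE_THRESHOLD):
    delta = new_energy - initial_energy
    # A non-finite new energy always counts as a divergence;
    # `not (delta <= threshold)` is also True when delta is NaN.
    return not math.isfinite(new_energy) or not (delta <= threshold)


def first_step(initial_state, proposal, initial_energy, new_energy):
    """Return (state, acceptance_probability, is_diverging) for the first step."""
    if is_divergent(initial_energy, new_energy):
        # Stop building the trajectory: keep the initial state, accept nothing.
        return initial_state, 0.0, True
    # Metropolis acceptance probability, only computed for finite energies
    p_accept = math.exp(min(0.0, initial_energy - new_energy))
    return proposal, p_accept, False
```

With the numbers from this issue, `first_step(100.0, -1.2375e101, 1.2267e23, math.inf)` returns `(100.0, 0.0, True)`: the initial state is kept and the exponential is never evaluated on an infinite energy difference.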

I currently suspect (but this still needs to be confirmed) that the trajectory builder does not return immediately when the first step diverges, which leads to an `at.exp(np.inf - np.inf)` operation that evaluates to `NaN`.
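The suspected failure mode is easy to reproduce in NumPy. The sketch below shows that `exp(inf - inf)` is `NaN`, and how a max-shifted log-sum-exp (the kind of reduction a multinomial NUTS sampler can use to combine proposal weights) produces the same `NaN` once every energy has overflowed. Whether aehmc hits exactly this path is an assumption; `naive_logsumexp` and `guarded_logsumexp` are hypothetical helpers, not aehmc code.

```python
import numpy as np

with np.errstate(invalid="ignore"):
    # The operation suspected in the issue: inf - inf is NaN, and exp(NaN) is NaN.
    suspect = np.exp(np.inf - np.inf)


def naive_logsumexp(log_weights):
    # Max-shifted log-sum-exp without a guard: if every weight is -inf,
    # the shift computes -inf - (-inf) = NaN.
    m = np.max(log_weights)
    return m + np.log(np.sum(np.exp(log_weights - m)))


def guarded_logsumexp(log_weights):
    m = np.max(log_weights)
    if not np.isfinite(m):
        # All weights are -inf (or the max is +inf or NaN): the shift is
        # meaningless, so return the max unchanged.
        return m
    return m + np.log(np.sum(np.exp(log_weights - m)))


# Log-weights -H for two states whose energies H overflowed to +inf
log_w = np.array([-np.inf, -np.inf])
with np.errstate(invalid="ignore"):
    naive = naive_logsumexp(log_w)  # NaN
guarded = guarded_logsumexp(log_w)  # -inf
```

The guarded version returns `-inf` (zero total weight) instead of `NaN`, which downstream code can treat as "no acceptable proposal" rather than as a numerical error.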