This update makes the `assert_within_limits` method of the `Prior` class play nicely with JAX. Because the `value` passed into the function will be a JAX traced array, `jax.lax.cond` must be used in place of a traditional `if` block.
As one branch will `raise` an exception, it must be wrapped in a `jax.debug.callback()`. `jax.lax.cond` traces both branches, so a bare `raise` would fire during tracing; the callback defers it until the `jit`ted function actually runs.
Also, `jax.numpy.logical_or` is used in the condition definition so that it works on the traced array as expected, since Python's `or` would try to convert the tracer to a `bool`.
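A minimal sketch of the pattern described above. `PriorSketch`, its `lower`/`upper` attributes and the `out_of_limits` helper are hypothetical names chosen for illustration (the real `Prior` class is not reproduced here). The `jnp.any` reduction is an added assumption so that array-valued inputs still give `jax.lax.cond` the scalar predicate it requires.

```python
import jax
import jax.numpy as jnp


class PriorSketch:
    """Hypothetical stand-in for the Prior class; only the limit check is shown."""

    def __init__(self, lower, upper):
        # Assumption: limits are plain Python floats, so they are static under jit.
        self.lower = lower
        self.upper = upper

    def out_of_limits(self, value):
        # jnp.logical_or works elementwise on traced arrays, whereas Python's
        # `or` would call bool() on a tracer. jnp.any reduces the result to the
        # scalar predicate jax.lax.cond needs (assumption: value may be an array).
        return jnp.any(jnp.logical_or(value < self.lower, value > self.upper))

    def _raise_out_of_limits(self, value):
        # Host-side callback: receives a concrete array, not a tracer.
        raise ValueError(
            f"value {value} is outside the limits [{self.lower}, {self.upper}]"
        )

    def assert_within_limits(self, value):
        def raise_branch(v):
            # Only the callback is staged here; the raise happens at run time.
            jax.debug.callback(self._raise_out_of_limits, v)

        def ok_branch(v):
            return None

        # Both branches are traced, so neither may raise during tracing.
        jax.lax.cond(self.out_of_limits(value), raise_branch, ok_branch, value)
        return value
```

Calling `jax.jit(prior.assert_within_limits)(jnp.array(1.5))` on a `PriorSketch(0.0, 1.0)` should trigger the `ValueError` inside the host callback. Exactly how that error surfaces to the caller (for example, wrapped in a runtime error) may differ between JAX versions and backends.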
Side note: the `if jax_wrapper.use_jax:` block does not need this kind of treatment. `use_jax` is a static Python value when the function is `jit`ted, so the branch is resolved at trace time and the compiler knows which one to take.
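The difference between a static flag and a traced value can be seen in a small sketch. Here `jax_wrapper` is a hypothetical `SimpleNamespace` stand-in exposing a `use_jax` attribute, and `scale_static`/`scale_traced` are illustrative functions that are not part of this change.

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real jax_wrapper: use_jax is a plain Python bool.
jax_wrapper = SimpleNamespace(use_jax=True)


def scale_static(x):
    # Static flag: this `if` is evaluated once while tracing, and only the
    # taken branch ends up in the compiled function.
    if jax_wrapper.use_jax:
        return jnp.asarray(x) * 2.0
    return x * 2.0


def scale_traced(x):
    # Under jit, `x > 0` is a traced array, so Python's `if` cannot decide it
    # and JAX raises a concretization error at trace time.
    if x > 0:
        return x * 2.0
    return x
```

`jax.jit(scale_static)(3.0)` compiles and returns `6.0`, while `jax.jit(scale_traced)(3.0)` fails at trace time; that second case is the one that needs `jax.lax.cond`.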