rhayes777 / PyAutoFit

PyAutoFit: Classy Probabilistic Programming
https://pyautofit.readthedocs.io/
MIT License

Update `assert_within_limits` on the `Prior` class #1020

Closed: CKrawczyk closed this pull request 3 months ago.

CKrawczyk commented 3 months ago

This update makes the `assert_within_limits` method of the `Prior` class play nicely with JAX. Because the value passed into the function is a JAX traced array, `jax.lax.cond` must be used in place of a traditional `if` block.
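A minimal sketch of the general pattern, using a hypothetical function `zero_if_negative` that is not part of PyAutoFit: inside `jax.jit` the argument is a tracer, so a Python `if` cannot decide which branch to take, while `jax.lax.cond(pred, true_fun, false_fun, operand)` accepts the traced boolean directly.

```python
import jax
import jax.numpy as jnp


@jax.jit
def zero_if_negative(x):
    # `if x < 0:` would fail here: inside jit, `x` is a tracer whose truth
    # value is unknown during tracing. lax.cond takes the traced boolean
    # and selects the branch when the compiled function runs.
    return jax.lax.cond(
        x < 0,
        lambda v: jnp.zeros_like(v),  # used when the predicate is True
        lambda v: v,                  # used when the predicate is False
        x,
    )
```

Both branches must return values with the same shape and dtype, which is why the sketch uses `jnp.zeros_like` rather than a bare Python `0`.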

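Because `jax.lax.cond` traces both branches, the branch that raises an exception must be wrapped in `jax.debug.callback()`. Otherwise the exception would be raised while the function is being traced, whatever the value turns out to be; inside the callback it is only evaluated when the jitted function actually runs and that branch is taken.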
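A sketch of how the raising branch can be deferred, assuming hypothetical fixed limits `LOWER`/`UPPER` and a hypothetical `check_within_limits` function rather than the actual `Prior.assert_within_limits` code. The test only exercises the in-limits path and the fact that lowering (which traces both branches) does not trigger the `raise`; how JAX surfaces an exception raised inside a callback at run time can vary between JAX versions.

```python
import jax
import jax.numpy as jnp

LOWER, UPPER = 0.0, 1.0  # hypothetical prior limits for this sketch


def _raise_out_of_limits(value):
    # Runs on the host with a concrete value, and only when this branch
    # is actually taken at run time.
    raise ValueError(f"value {value} is outside [{LOWER}, {UPPER}]")


def check_within_limits(value):
    out_of_limits = jnp.logical_or(value < LOWER, value > UPPER)
    jax.lax.cond(
        out_of_limits,
        # lax.cond traces both branches, so a bare `raise` here would fire
        # during tracing; debug.callback defers it to run time instead.
        lambda v: jax.debug.callback(_raise_out_of_limits, v),
        lambda v: None,  # in-limits branch: nothing to do
        value,
    )
    return value


check_jit = jax.jit(check_within_limits)
```

Returning `value` gives the caller something to block on, so a run-time failure from the callback has a result to attach to.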

`jax.numpy.logical_or` is also used to build the condition, since Python's `or` cannot be applied to a traced array: it needs a concrete boolean, which a tracer cannot provide.
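A short sketch contrasting the two approaches, with hypothetical limits `LOWER`/`UPPER`: Python's `or` calls `bool()` on the first traced comparison and fails under `jax.jit`, while `jnp.logical_or` combines the traced booleans element-wise and also works on arrays.

```python
import jax
import jax.numpy as jnp

LOWER, UPPER = 0.0, 1.0  # hypothetical limits for this sketch


@jax.jit
def out_of_limits(value):
    # jnp.logical_or operates on the traced booleans without needing
    # their concrete values during tracing.
    return jnp.logical_or(value < LOWER, value > UPPER)


def out_of_limits_with_python_or(value):
    # `or` forces bool(tracer), which JAX rejects inside jit.
    return (value < LOWER) or (value > UPPER)
```

Calling `jax.jit(out_of_limits_with_python_or)` on any input raises a `jax.errors.ConcretizationTypeError` (in recent JAX, the more specific `TracerBoolConversionError` subclass).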

Side note: the `if jax_wrapper.use_jax:` block does not need this kind of treatment, because `jax_wrapper.use_jax` is a static Python value when the function is jitted, so the branch is chosen once at trace time.
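A sketch of why a static flag is fine in an ordinary `if`, assuming a hypothetical module-level `USE_JAX` bool as a stand-in for `jax_wrapper.use_jax`: the flag is a plain Python value rather than a tracer, so the `if` is resolved during tracing and only the chosen branch appears in the traced program, with no `cond` primitive.

```python
import jax
import jax.numpy as jnp
import numpy as np

USE_JAX = True  # hypothetical stand-in for jax_wrapper.use_jax


def out_of_limits(value, lower=0.0, upper=1.0):
    # USE_JAX is a plain Python bool, not a traced value, so this `if`
    # is resolved while the function is traced: only one branch ends up
    # in the compiled program and no lax.cond is needed.
    if USE_JAX:
        return jnp.logical_or(value < lower, value > upper)
    return np.logical_or(value < lower, value > upper)


out_of_limits_jit = jax.jit(out_of_limits)
```

Inspecting `jax.make_jaxpr(out_of_limits)(0.5)` shows only the comparison and `or` operations, confirming that the flag never reached the compiled code as a branch.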