To use the JAX backend directly, the boolean mask and advanced set-subtensor operations in the graph of `PartiallyObservedRV`s must be constant folded.
With the default mutable data/coords this does not happen, and the recommended solution is for users to call `freeze_rv_and_dims`. We should add a test that this actually works, as suggested by @jessegrabowski.
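For context, here is a minimal sketch in plain JAX of why the mask has to be a constant. It is not PyMC's implementation: `fill_missing` and the toy data are hypothetical, and the sketch only assumes that boolean-mask set indexing under `jax.jit` needs a concrete mask, because the number of selected elements determines output shapes.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical data: NaN marks the missing (unobserved) entries.
data = np.array([1.0, np.nan, 3.0, np.nan])
mask = np.isnan(data)      # concrete NumPy boolean mask
observed = data[~mask]     # the two observed values
unobserved = np.array([10.0, 20.0])  # stand-in for draws of the missing values


def fill_missing(unobserved, observed, mask):
    """Join both parts into one vector with boolean-mask set indexing."""
    joined = jnp.zeros(mask.shape[0])
    joined = joined.at[mask].set(unobserved)
    joined = joined.at[~mask].set(observed)
    return joined


# Case 1: the mask is a compile-time constant (closed over, not traced).
# This mimics a graph in which the mask has been constant folded.
constant_mask_fn = jax.jit(lambda u, o: fill_missing(u, o, mask))
result = constant_mask_fn(unobserved, observed)

# Case 2: the mask is a traced argument, as with mutable data.
# JAX cannot turn a traced boolean mask into static indices.
traced_mask_fn = jax.jit(fill_missing)
try:
    traced_mask_fn(unobserved, observed, mask)
    traced_mask_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    traced_mask_failed = True

print(result)
print("traced mask failed:", traced_mask_failed)
```

The test requested here would do the PyMC-level equivalent: build a model with missing values in the observed data, apply `freeze_rv_and_dims`, and check that the resulting graph compiles with the JAX backend.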