pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Predictive fix when deterministic sites are present #1789

Closed: kylejcaron closed this 2 months ago

kylejcaron commented 2 months ago

This PR attempts to fix #1772: when deterministic sites are included in the values passed to `Predictive` (for example via `params`), `Predictive` substitutes the stored values and does not generate new samples for those deterministic sites.

This PR fixes that by ignoring deterministic sites in the `substitute` call inside `Predictive`, so their values are recomputed instead of being replayed. It also adds a new test covering this scenario.
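The mechanism can be illustrated with a toy, pure-Python stand-in for a model runner. The `Runner` class, the `run` helper and the `skip_deterministic` flag are hypothetical and are not NumPyro's internals or API. Stored values for sample sites are substituted, as `Predictive` does; a deterministic site is either overwritten by its stored value (the behaviour reported in #1772) or recomputed from the substituted sites (the behaviour this PR aims for). Predicting at new inputs is used here only as an example scenario.

```python
import random
from typing import Any, Callable, Dict, List


class Runner:
    """Toy effect-handler stand-in (hypothetical, not NumPyro's implementation).

    Records every site in `trace` and replays values from `substitutions`.
    """

    def __init__(self, substitutions: Dict[str, Any], skip_deterministic: bool):
        self.substitutions = substitutions
        self.skip_deterministic = skip_deterministic
        self.trace: Dict[str, Dict[str, Any]] = {}

    def sample(self, name: str, draw: Callable[[], Any]) -> Any:
        # A stored value replaces the random draw, as a substitute handler would.
        if name in self.substitutions:
            value = self.substitutions[name]
        else:
            value = draw()
        self.trace[name] = {"type": "sample", "value": value}
        return value

    def deterministic(self, name: str, value: Any) -> Any:
        # Without skipping, the stored (possibly stale) value overrides the
        # freshly computed one; with skipping, the computed value is kept.
        if not self.skip_deterministic and name in self.substitutions:
            value = self.substitutions[name]
        self.trace[name] = {"type": "deterministic", "value": value}
        return value


def model(rt: Runner, x: List[float]) -> List[float]:
    # One latent sample site and one deterministic site derived from it and x.
    slope = rt.sample("slope", lambda: random.gauss(0.0, 1.0))
    return rt.deterministic("mean", [slope * xi for xi in x])


def run(substitutions: Dict[str, Any], skip_deterministic: bool, x: List[float]):
    rt = Runner(substitutions, skip_deterministic)
    model(rt, x)
    return rt.trace


# Values stored from a run at x = [1.0, 2.0], including the deterministic site.
stored = {"slope": 2.0, "mean": [2.0, 4.0]}
new_x = [10.0, 20.0]

stale = run(stored, skip_deterministic=False, x=new_x)
fresh = run(stored, skip_deterministic=True, x=new_x)
print(stale["mean"]["value"])  # [2.0, 4.0]
print(fresh["mean"]["value"])  # [20.0, 40.0]
```

With `skip_deterministic=False` the new inputs are silently ignored for `mean`, because the replayed value wins; skipping deterministic sites during substitution lets them be recomputed from the substituted `slope`.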