CDCgov / PyRenew

Python package for multi-signal Bayesian renewal modeling with JAX and NumPyro.
https://cdcgov.github.io/PyRenew/
Apache License 2.0

Tutorial / Model Demonstrating limitations of `scan` with `numpyro.plate` #444

Open · damonbayer opened this issue 1 month ago

damonbayer commented 1 month ago

Our `RandomVariable`s that use `scan` cannot be used inside a `numpyro.plate` context (see https://num.pyro.ai/en/stable/primitives.html#scan). However, we can still write reasonable-looking models by using `numpyro.plate` for the non-scanning `RandomVariable`s and passing their results as inputs to our vectorized scanning `RandomVariable`s. We should have a tutorial or model that demonstrates this limitation and our workaround.

dylanhmorris commented 1 month ago

I think for deterministic scans we should probably stick with `jax.lax.scan`, but let's discuss in https://github.com/CDCgov/PyRenew/issues/444
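For the deterministic case, a sketch of a plain `jax.lax.scan` computing the renewal equation I_t = R_t * sum_s w_s * I_{t-s}, vectorized over groups. The function `renewal_infections` and its argument layout are hypothetical, not PyRenew's API. Because the step function contains no `numpyro.sample` sites, it is ordinary JAX code and can take arrays produced by plated sample sites (for example a per-group R_t) as inputs.

```python
import jax
import jax.numpy as jnp


def renewal_infections(init_infections, rt, gen_int):
    """Hypothetical deterministic renewal process, vectorized over groups.

    init_infections: (len(gen_int), n_groups), oldest row first
    rt:              (n_steps, n_groups) reproduction numbers
    gen_int:         (len(gen_int),) generation-interval pmf, w_1 first
    Returns new infections with shape (n_steps, n_groups).
    """
    # Reverse the pmf so the most recent infections (last row) get weight w_1.
    w_rev = gen_int[::-1]

    def step(history, r):
        # Weighted sum over lags for every group at once, then scale by R_t.
        new = r * jnp.einsum("s,sg->g", w_rev, history)
        # Drop the oldest row and append the new infections.
        history = jnp.concatenate([history[1:], new[None, :]], axis=0)
        return history, new

    _, infections = jax.lax.scan(step, init_infections, rt)
    return infections


# Usage: all generation-interval mass at lag 1, R_t = 2, two groups.
out = renewal_infections(jnp.ones((1, 2)), jnp.full((3, 2), 2.0), jnp.array([1.0]))
# out holds [[2, 2], [4, 4], [8, 8]]
```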