The truncated distributions in NumPyro currently operate in lazy mode: cached JAX tracer arrays are reused rather than recomputed during sampling. This works for posterior sampling but not for forecasting, which requires these arrays to be recomputed when using `numpyro.infer.Predictive()`.
The following PR by DHM and SB addresses this (the author of this issue has not looked at how), but the changes have not yet been merged into NumPyro:
This issue exists to track the merging of this PR into NumPyro, so that a truncated Normal distribution can be used for the initial susceptibility prior in `pyrenew-flu-light`.
The PyRenew main branch now uses the dev version of NumPyro, so truncated distributions can be used after running `poetry update pyrenew` locally. See https://github.com/CDCgov/PyRenew/pull/421