CDCgov / PyRenew

Python package for multi-signal Bayesian renewal modeling with JAX and NumPyro.
https://cdcgov.github.io/PyRenew/
Apache License 2.0

create a censored normal distribution #428

Closed: sbidari closed this 1 week ago

sbidari commented 1 week ago

closes #427

codecov[bot] commented 1 week ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.72%. Comparing base (3c5fbe7) to head (f677032). Report is 1 commit behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #428      +/-   ##
==========================================
+ Coverage   93.51%   93.72%   +0.20%
==========================================
  Files          41       43       +2
  Lines        1018     1052      +34
==========================================
+ Hits          952      986      +34
  Misses         66       66
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `93.72% <100.00%> (+0.20%)` | ↑ |

Flags with carried forward coverage won't be shown.


SamuelBrand1 commented 1 week ago

I think it's worth reusing existing JAX functions where possible, both to reduce code length and maybe to pick up their optimization tricks.

Is it not possible to reuse `jax.scipy.stats.truncnorm.logpdf` (https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.truncnorm.logpdf.html) and `jax.random.truncated_normal` (https://jax.readthedocs.io/en/latest/_autosummary/jax.random.truncated_normal.html#jax.random.truncated_normal)?
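For reference, a minimal sketch of how the two JAX functions named above are typically called, assuming a normal with illustrative `loc`, `scale`, `low`, and `high` values (not taken from the PR). Both functions work on the standardized scale, and both implement *truncation*: the density is renormalized over the bounds and is `-inf` outside them.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import truncnorm

# Illustrative Normal(loc, scale) truncated to [low, high]. Both JAX functions
# take the bounds on the standardized scale, so convert them first.
loc, scale, low, high = 1.0, 2.0, 0.0, 3.0
a, b = (low - loc) / scale, (high - loc) / scale

# Log-density of the truncated normal: renormalized over [low, high],
# and -inf for values outside the bounds.
x = jnp.array([-1.0, 0.5, 2.5, 4.0])
log_dens = truncnorm.logpdf(x, a, b, loc=loc, scale=scale)

# Sampling: truncated_normal draws a standard normal truncated to (a, b);
# shift and scale the draws back to the original units.
key = jax.random.PRNGKey(0)
draws = loc + scale * jax.random.truncated_normal(key, a, b, shape=(1000,))
```

Because no probability mass is assigned at the bounds, these functions model truncation rather than censoring, which is the distinction raised in the follow-up comment.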

SamuelBrand1 commented 1 week ago

Oops! I reread the issue: you are indeed doing censoring!

It's a bit confusing because you're doing truncated sampling within a censoring context. So maybe this should be described as a truncated normal distribution, even though it is ultimately used to handle the problems caused by censored observations?
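To make the truncation/censoring distinction concrete, here is a small sketch (illustrative values, not the PR's code) comparing the two sampling schemes for a standard normal with bounds at ±1. Censored draws are clamped to the bounds, so the bounds carry point masses; truncated draws never land on a bound.

```python
import jax
import jax.numpy as jnp

loc, scale, low, high = 0.0, 1.0, -1.0, 1.0
key_c, key_t = jax.random.split(jax.random.PRNGKey(42))
n = 10_000

# Censoring: draw from the full normal, then clamp to the bounds. Values
# beyond a bound are recorded *at* that bound, so the bounds carry point
# masses P(X <= low) and P(X >= high).
censored = jnp.clip(loc + scale * jax.random.normal(key_c, (n,)), low, high)

# Truncation: values beyond the bounds are never generated; the density is
# renormalized over (low, high) and the bounds themselves get no mass.
a, b = (low - loc) / scale, (high - loc) / scale
truncated = loc + scale * jax.random.truncated_normal(key_t, a, b, shape=(n,))

# For N(0, 1) with bounds at +/-1, roughly 32% of censored draws sit exactly
# on a bound; the truncated draws stay strictly inside the interval.
frac_on_bound_censored = jnp.mean((censored == low) | (censored == high))
frac_on_bound_truncated = jnp.mean((truncated == low) | (truncated == high))
```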

damonbayer commented 1 week ago

Does this differ in any substantial way from the example in the NumPyro docs: https://num.pyro.ai/en/latest/tutorials/censoring.html ? I have not reviewed this PR or the doc yet, but am curious if there are any major differences.

dylanhmorris commented 1 week ago

> Does this differ in any substantial way from the example in the NumPyro docs: https://num.pyro.ai/en/latest/tutorials/censoring.html ? I have not reviewed this PR or the doc yet, but am curious if there are any major differences.

  1. The docs treat it as a modeling problem using only bundled distributions; this PR creates a `numpyro.distributions.Distribution` subclass to handle censoring directly. I strongly favor the second approach where possible, for modularity, testing, and reusability.
  2. This PR does not use a boolean array to indicate which values are censored; instead, it treats any value at or outside the given censoring bounds as censored.

It would be a nice feature for NumPyro to support semi-automated creation of censored distributions from base distributions (as it currently does for truncated distributions). But as far as I can see, that does not yet exist.