CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Refactor transform module to wrap `numpyro.distributions.transforms` #140

Closed: gvegayon closed this pull request 1 month ago

gvegayon commented 1 month ago
codecov[bot] commented 1 month ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 100.00%. Comparing base (77dc819) to head (ef26272). Report is 1 commit behind head on main.

Additional details and impacted files

```diff
@@             Coverage Diff              @@
##              main      #140      +/-   ##
============================================
+ Coverage    92.10%   100.00%   +7.89%
============================================
  Files           33         2      -31
  Lines          671         7     -664
============================================
- Hits           618         7     -611
+ Misses          53         0      -53
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `100.00% <ø> (+7.89%)` | ⬆️ |

Flags with carried forward coverage won't be shown.


damonbayer commented 1 month ago
  1. I think any custom transforms we want to keep should still be in their own module.
  2. Can we not just define LogTransform as ExpTransform().inv? Perhaps that doesn't get us all the functionality of a true numpyro transform (maybe it does), but it at least gives us the transform and its inv, which is all we need.
import numpyro.distributions.transforms as nT
import pyrenew.transform as pT
import jax.numpy as jnp
from numpy.testing import assert_array_almost_equal

test_array = jnp.array([0.1, 0.2, 0.4])

LogTransform = nT.ExpTransform().inv  # the inverse of exp is log

LogTransform(test_array)  # log(test_array)
LogTransform.inv(test_array)  # exp(test_array)
gvegayon commented 1 month ago

I updated the code to use ExpTransform().inv. As for the other transforms, maybe ScaledLogitTransform is the one to keep? But I'm not sure where to put it, since it would be a module with a single function.

# ScaledLogitTransform
assert_array_almost_equal(
    pT.ScaledLogitTransform(4)(test_array),
    nT.ComposeTransform([nT.AffineTransform(0,1/4), nT.SigmoidTransform().inv])(test_array)
)

assert_array_almost_equal(
    pT.ScaledLogitTransform(4).inverse(test_array),
    nT.ComposeTransform([nT.AffineTransform(0,1/4), nT.SigmoidTransform().inv]).inv(test_array)
)
dylanhmorris commented 1 month ago

I don't think we want to make the user remember which transforms are in numpyro.distributions.transforms and which are in pyrenew.transform. Proposal: expose transforms from numpyro.distributions.transforms to the user via pyrenew.transform:

That is, make pyrenew/transform a package (a folder) with an __init__.py that reads something like:

from pyrenew.transform.scaledlogittransform import ScaledLogitTransform
from numpyro.distributions.transforms import ExpTransform

__all__ = [
    "ScaledLogitTransform",
    "ExpTransform"
]

Then the user can do

from pyrenew.transform import ExpTransform

Are there reasons not to do this? Documentation? Politeness/attribution? Is it considered an antipattern for a reason I haven't thought of?

damonbayer commented 1 month ago

@dylanhmorris I would want to ensure that all transforms from numpyro are re-exported in our module, ideally without having to list them all out.

dylanhmorris commented 1 month ago

This is a case in which an import * might be allowed. Something like:

pyrenew/transform/__init__.py:

from numpyro.distributions.transforms import * 
from numpyro.distributions.transforms import __all__ as numpyro_public_transforms

from pyrenew.transform.scaledlogittransform import ScaledLogitTransform

__all__ = [
    "ScaledLogitTransform"
] + numpyro_public_transforms