CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Abandon or rewrite custom Transform class in favor of `numpyro.distributions.transforms` #128

Closed · damonbayer closed this issue 1 month ago

damonbayer commented 1 month ago

I don't think there is a good reason to roll our own transform class when numpyro already provides one. Each of our transforms is either already implemented in numpyro or can easily be built with `ComposeTransform`, as the following equivalence checks show:

```python
# pT: pyrenew's custom transforms; nT: numpyro's built-in transforms
import numpyro.distributions.transforms as nT
import pyrenew.transform as pT
import jax.numpy as jnp
from numpy.testing import assert_array_almost_equal

test_array = jnp.array([0.1, 0.2, 0.4])

# Quick check: log is the inverse of exp
assert_array_almost_equal(
    nT.ExpTransform().inv(test_array),
    pT.LogTransform()(test_array)
)

# IdentityTransform
assert_array_almost_equal(
    pT.IdentityTransform()(test_array),
    nT.IdentityTransform()(test_array)
)

assert_array_almost_equal(
    pT.IdentityTransform().inverse(test_array),
    nT.IdentityTransform().inv(test_array)
)

# LogTransform
assert_array_almost_equal(
    pT.LogTransform()(test_array),
    nT.ExpTransform().inv(test_array)
)

assert_array_almost_equal(
    pT.LogTransform().inverse(test_array),
    nT.ExpTransform()(test_array)
)

# LogitTransform (logit is the inverse of sigmoid)
assert_array_almost_equal(
    pT.LogitTransform()(test_array),
    nT.SigmoidTransform().inv(test_array)
)

assert_array_almost_equal(
    pT.LogitTransform().inverse(test_array),
    nT.SigmoidTransform()(test_array)
)

# ScaledLogitTransform(4) vs. logit(x / 4): scale by 1/4, then apply logit
assert_array_almost_equal(
    pT.ScaledLogitTransform(4)(test_array),
    nT.ComposeTransform([nT.AffineTransform(0, 1 / 4), nT.SigmoidTransform().inv])(test_array)
)

assert_array_almost_equal(
    pT.ScaledLogitTransform(4).inverse(test_array),
    nT.ComposeTransform([nT.AffineTransform(0, 1 / 4), nT.SigmoidTransform().inv]).inv(test_array)
)
```

dylanhmorris commented 1 month ago

Noting the sync discussion, in which we agreed (I believe) to:

  1. Expose all public transforms in numpyro.distributions.transforms to the user via the pyrenew.transform module
  2. Construct any additional transforms pyrenew provides (e.g., ScaledLogit) from numpyro transforms using numpyro.distributions.transforms.ComposeTransform, and expose them to the user via the same pyrenew.transform module.