CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Time handling (related to data class) #185

Closed gvegayon closed 4 days ago

gvegayon commented 2 weeks ago

As models become more complex, we need a way to keep track of time units. Whatever solution we end up implementing, it should address the following issues (from internal notes):

This is closely related to #40.

Update June 26

This issue led to the following two issues:

gvegayon commented 1 week ago

Some thoughts about this issue:

  1. Part of the solution could be a clock() member for data arrays. In principle, it could be a thin wrapper around jax.numpy.array that carries (a) the relative time unit and (b) the time reference of the first element of the array. Copilot suggests we can achieve this using delegation; here is a code snippet illustrating the idea:

    import jax.numpy as jnp
    
    class PyRenewArray:
        def __init__(self, data, time_unit=None, time_reference=None):
            self.data = jnp.array(data)
            self.time_unit = time_unit
            self.time_reference = time_reference
    
        def __getattr__(self, attr):
            # delegate attribute access to the underlying jax.numpy array
            return getattr(self.data, attr)
    
    x = PyRenewArray([1, 2, 3])
    print(x.mean())  # delegates to x.data.mean()
  2. Either PyRenewArray has a method, or there's a function that allows syncing arrays; for instance, in the case of latent.Infections(), the sampling function could take all arguments (Rt, I0, gen_time) and sync them before doing anything:

    def sample(Rt, I0, gen_time):
    
        # Make sure all are in the same time unit
        Rt, I0, gen_time = sync_arrays(Rt, I0, gen_time, unit="days")

    sync_arrays() would read the time_unit and time_reference from the PyRenewArray attributes so that the arrays are:

    • Padded left to match the first array's starting point.
    • Broadcast/interpolated/aggregated to daily units.
    • Padded right (?) to match the longest time/series.
    • Etc.

    In other words, make sure everything has the same length and has matching days (e.g., all start on Wednesday, June 19th, if that's the reference date Rt has).
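    The steps listed above can be sketched roughly as follows. This is only an illustration under several assumptions: TimedSeries is a hypothetical stand-in for PyRenewArray, plain Python lists stand in for JAX arrays, time_unit is the number of days per element, time_reference is an integer day index rather than a calendar date, and NaN is used as the padding value. The signature also differs from the call above, which passes unit="days".

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TimedSeries:
    """Hypothetical stand-in for PyRenewArray: values plus time metadata."""

    values: List[float]
    time_unit: int = 1       # days covered by each element (7 = weekly)
    time_reference: int = 0  # day index of the first element


def sync_arrays(*series, fill=float("nan")):
    """Return daily lists that share the same start day and length."""
    # Broadcast each series to daily resolution by repeating every value.
    daily = [
        [v for v in s.values for _ in range(s.time_unit)]
        for s in series
    ]
    # Common window: earliest start day to latest end day.
    start = min(s.time_reference for s in series)
    end = max(s.time_reference + len(d) for s, d in zip(series, daily))

    synced = []
    for s, d in zip(series, daily):
        left = [fill] * (s.time_reference - start)          # pad left
        right = [fill] * (end - s.time_reference - len(d))  # pad right
        synced.append(left + d + right)
    return synced


rt = TimedSeries([1.2, 0.9], time_unit=7, time_reference=0)          # weekly
i0 = TimedSeries([10.0, 12.0, 15.0], time_unit=1, time_reference=3)  # daily
rt_daily, i0_daily = sync_arrays(rt, i0)
```

    Here both outputs cover days 0 through 13: the weekly Rt values are repeated seven times each, and the daily I0 values sit at days 3 to 5 with padding on both sides. Interpolation or aggregation, also mentioned above, is not covered by this sketch.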

damonbayer commented 1 week ago

Re: time reference, one challenge I foresee is slicing arrays. When we pass x[3:] somewhere, how could we indicate it has a new time reference that is different from x?

gvegayon commented 1 week ago

Great point! My first impression is that you'd have to implement an indexing method that is aware of time (not ideal). Also, as I dig deeper into it, this is starting to look more complicated than I'd like. Delegation doesn't make JAX happy. NumPy has a formal, well-defined way of subclassing, but it doesn't work with JAX arrays (and they don't recommend it).
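
For illustration, a time-aware indexing method could do bookkeeping along these lines. This is a minimal sketch, not a proposal: TimedArray is a hypothetical wrapper, a plain list stands in for the array, time_reference is assumed to be an integer index in days, and the sketch does not address the JAX compatibility problems mentioned above.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class TimedArray:
    """Hypothetical wrapper: values plus the time index of element 0."""

    data: Sequence[float]
    time_unit: int = 1
    time_reference: int = 0

    def __getitem__(self, key):
        if not isinstance(key, slice):
            # Scalar access returns the bare value, without time metadata.
            return self.data[key]
        start, _, step = key.indices(len(self.data))
        if step < 1:
            raise ValueError("only forward slices keep a meaningful time axis")
        return TimedArray(
            data=self.data[key],
            # A strided slice makes each element span `step` original units.
            time_unit=self.time_unit * step,
            # Shift the reference by the number of skipped time units.
            time_reference=self.time_reference + start * self.time_unit,
        )


x = TimedArray([1.0, 2.0, 3.0, 4.0, 5.0], time_unit=1, time_reference=10)
y = x[3:]
print(y.time_reference)  # 13
```

With this bookkeeping, x[3:] reports a reference of 13 instead of inheriting 10 from x, which is the behavior the slicing question asks for.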

gvegayon commented 1 week ago

The closest I got is using views to cast between types. Here is some code:

import jax.numpy as jnp
import numpy as np

class PyRenewArray(np.ndarray):
    _pyrenew_attr : float = None

x = jnp.array([1, 2, 3])

np.asarray(x).view(PyRenewArray)

This prints:

PyRenewArray([1, 2, 3], dtype=int32)

But this breaks as soon as traced arrays come into play, so wrapping seems to be a no-go.

dylanhmorris commented 1 week ago

I'd be interested in exploring how they do things in Rockpool

gvegayon commented 6 days ago

> I'd be interested in exploring how they do things in Rockpool

What I like:

Questions:

It'd be great to see some examples dealing with the problems we are trying to deal with.

gvegayon commented 4 days ago

After this morning's meeting, we have agreed on the following:

  1. The Rockpool package will be considered later on to provide users with smooth pre- and post-processing tooling.

  2. We agreed that having RandomVariables carrying information about time is a good solution, regardless of whether the variable needs to be time-aware.

  3. As a minimum and first implementation, all instances of RandomVariable will have the following two attributes: timeseries_start : int = None and timeseries_unit : int = 1.

  4. This will lead to two PRs:

    • [ ] Incorporating these new attributes in all instances of RandomVariable, adding the corresponding values of _start and _unit when needed (e.g., the weekly effects should have timeseries_unit = 7).
    • [ ] Using the new attributes, remove padding whenever possible (hopefully everywhere).

These two will be assigned to me and hopefully resolved during this sprint. Nonetheless, I believe the padding work may stretch into the next sprint. Do you agree, @dylanhmorris @damonbayer @AFg6K7h4fhy2?
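
As a rough illustration of point 3, the two attributes could be used to align series without padding. This is a sketch under assumptions: TimeAwareRV is a hypothetical stand-in and not the actual RandomVariable class, element_for_day is an invented helper, and timeseries_start is taken to be an integer day index.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TimeAwareRV:
    """Hypothetical stand-in for RandomVariable with the two agreed attributes."""

    name: str
    timeseries_start: Optional[int] = None  # None: not tied to a time axis
    timeseries_unit: int = 1                # days per element


def element_for_day(rv: TimeAwareRV, day: int) -> int:
    """Index of the element of `rv` that covers absolute day `day`."""
    if rv.timeseries_start is None:
        raise ValueError(f"{rv.name} has no time axis")
    offset = day - rv.timeseries_start
    if offset < 0:
        raise IndexError(f"day {day} precedes the start of {rv.name}")
    return offset // rv.timeseries_unit


infections = TimeAwareRV("infections", timeseries_start=0)
weekly_effect = TimeAwareRV("weekly_effect", timeseries_start=0, timeseries_unit=7)

# Day 9 is the 10th daily element but falls in the 2nd weekly element,
# so the weekly series can be indexed directly instead of padded to daily length.
print(element_for_day(infections, 9), element_for_day(weekly_effect, 9))  # 9 1
```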

gvegayon commented 4 days ago

I'm closing this for the moment. We can continue discussing details in the two new issues (which I'll link).