gvegayon closed this issue 4 days ago.
Some thoughts about this issue:

Part of the solution could be including a `clock()` member for data arrays. In principle, it could be a soft wrapper of `jax.numpy.array` that contains information about (a) the relative time unit and (b) the reference time of the first element of the array. Copilot suggests we can achieve this using delegation; here is a code snippet illustrating the idea:
```python
import jax.numpy as jnp


class PyRenewArray:
    def __init__(self, data, time_unit=None, time_reference=None):
        self.data = jnp.array(data)
        self.time_unit = time_unit
        self.time_reference = time_reference

    def __getattr__(self, attr):
        # Delegate attribute access to the underlying jax.numpy array.
        # __getattr__ is only consulted for attributes not found on the
        # instance, and never for operators such as x + 1.
        return getattr(self.data, attr)


x = PyRenewArray([1, 2, 3])
print(x.mean())  # delegates to x.data.mean()
```
Either `PyRenewArray` has a method, or there's a function that allows syncing arrays. For instance, in the case of `latent.Infections()`, the sampling function could take all arguments (`Rt`, `I0`, `gen_time`) and sync them before doing anything:
```python
def sample(Rt, I0, gen_time):
    # Make sure all are in the same time unit (sync_arrays is a proposed helper)
    Rt, I0, gen_time = sync_arrays(Rt, I0, gen_time, unit="days")
    ...
```
`sync_arrays()` would read the `time_unit` and `time_reference` attributes of each `PyRenewArray` to align them. In other words, it would make sure everything has the same length and matching days (e.g., all start on Wednesday, June 19th, if that's the reference date `Rt` has).
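A minimal sketch of what such a helper could do, assuming each array carries an integer `time_unit` (days per element) and an integer `time_reference` (day index of its first element) rather than a calendar date. `TimedArray` and this `sync_arrays` are hypothetical, and trimming to the overlapping window is only one possible alignment policy (padding would be another):

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class TimedArray:
    """Hypothetical container: values plus the time metadata discussed above."""
    data: jnp.ndarray
    time_unit: int = 1       # days covered by each element (7 = weekly)
    time_reference: int = 0  # day index of the first element


def sync_arrays(*arrays: TimedArray, unit: str = "days") -> list:
    """Hypothetical helper: put every array on a daily grid and keep only
    the days that all arrays cover."""
    if unit != "days":
        raise NotImplementedError("this sketch only aligns to daily resolution")
    # Expand to daily resolution by repeating each value time_unit times.
    daily = [jnp.repeat(jnp.asarray(a.data), a.time_unit) for a in arrays]
    starts = [a.time_reference for a in arrays]
    ends = [s + d.shape[0] for s, d in zip(starts, daily)]
    # Overlapping window [first_day, last_day) shared by all inputs.
    first_day, last_day = max(starts), min(ends)
    if first_day >= last_day:
        raise ValueError("arrays do not overlap in time")
    return [d[first_day - s : last_day - s] for d, s in zip(daily, starts)]


rt = TimedArray(jnp.array([1.0, 1.1, 1.2, 1.3]), time_unit=1, time_reference=2)
weekly = TimedArray(jnp.array([0.5, 0.9]), time_unit=7, time_reference=0)
rt_daily, weekly_daily = sync_arrays(rt, weekly)
```

Here `rt` covers days 2 to 5 and `weekly` covers days 0 to 13, so both outputs end up with length 4 and start on day 2.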
Re: time reference, one challenge I foresee is slicing arrays. When we pass `x[3:]` somewhere, how could we indicate that it has a new time reference, different from that of `x`?
Great point! My first impression is that you'd have to implement an indexing method that is aware of time (not ideal). Also, as I dig deeper into it, it is starting to look more complicated than I'd like. I tried delegation, and JAX isn't happy with it. NumPy has a formal, well-defined way of subclassing, but it doesn't work with JAX arrays (and the JAX developers don't recommend it).
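To make the "indexing method that is aware of time" concrete, here is a minimal sketch of a hypothetical wrapper whose `__getitem__` shifts the time reference when sliced. It assumes integer day indices and handles only contiguous slices; `TimedArray` is an illustrative name, not an existing class:

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class TimedArray:
    """Hypothetical wrapper used only to illustrate time-aware slicing."""
    data: jnp.ndarray
    time_unit: int = 1       # days per element
    time_reference: int = 0  # day index of data[0]

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Normalize negative or missing bounds against the array length.
            start, stop, step = key.indices(self.data.shape[0])
            if step != 1:
                raise NotImplementedError("only contiguous slices are supported")
            # The new first element sits `start` elements after the old one,
            # so shift the reference by start * time_unit days.
            return TimedArray(
                self.data[start:stop],
                time_unit=self.time_unit,
                time_reference=self.time_reference + start * self.time_unit,
            )
        # Integer (or other) indexing returns a plain value without metadata.
        return self.data[key]


x = TimedArray(jnp.arange(10), time_unit=7, time_reference=0)
y = x[3:]  # y.time_reference == 21
```

This shows why it is "not ideal": every indexing pattern (strides, fancy indexing, boolean masks) needs its own rule, and the wrapper is still not a JAX array, so it cannot be passed through `jax.jit` as-is.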
The closest I got is using views to cast between types. Here is some code:
```python
import jax.numpy as jnp
import numpy as np


class PyRenewArray(np.ndarray):
    _pyrenew_attr: float = None


x = jnp.array([1, 2, 3])
# View-cast the (concrete) JAX array to the NumPy subclass.
print(repr(np.asarray(x).view(PyRenewArray)))
```
And prints:

```
PyRenewArray([1, 2, 3], dtype=int32)
```
But this breaks as soon as traced arrays come into play (e.g., inside `jax.jit`), because NumPy cannot convert a tracer into a concrete array. So wrapping seems to be a no-go.
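A small check of that failure mode, reusing the `PyRenewArray` subclass from the snippet above: the view cast works on a concrete array, but under `jax.jit` the argument is a tracer and `np.asarray` raises `jax.errors.TracerArrayConversionError`. The helper names `to_pyrenew` and `works_under_jit` are only for illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np


class PyRenewArray(np.ndarray):
    _pyrenew_attr: float = None


def to_pyrenew(x):
    # View-cast a JAX array to the NumPy subclass. This needs concrete
    # values, which a tracer (what jax.jit passes in) does not have.
    return np.asarray(x).view(PyRenewArray)


def works_under_jit() -> bool:
    try:
        jax.jit(to_pyrenew)(jnp.array([1, 2, 3]))
    except jax.errors.TracerArrayConversionError:
        return False
    return True


print(type(to_pyrenew(jnp.array([1, 2, 3]))).__name__)  # PyRenewArray
print(works_under_jit())  # False
```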
I'd be interested in exploring how they do things in Rockpool
What I like:
Questions:
`pyrenew`? If so, how much weight will this add to the project? It'd be great to see some examples tackling the problems we are trying to solve.
After this morning's meeting, we agreed on the following:

- The Rockpool package will be considered later on to provide users with smooth pre- and post-processing tooling.
- Having `RandomVariable`s carry information about time is a good solution, regardless of whether the variable needs to be time-aware.
- As a minimal first implementation, all instances of `RandomVariable` will have the following two attributes: `timeseries_start: int = None` and `timeseries_unit: int = 1`.
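A minimal sketch of how the two agreed attributes could be used to recover the time index of each element of a variable's output. `TimeAwareRV` and `time_index` are hypothetical names, not pyrenew's actual `RandomVariable`, and treating `timeseries_start=None` as "not aligned in time" is an assumption:

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class TimeAwareRV:
    """Hypothetical stand-in for a RandomVariable carrying the two agreed
    attributes; this is not pyrenew's actual class."""
    timeseries_start: Optional[int] = None  # time index of the first element
    timeseries_unit: int = 1                # time steps per element (7 = weekly)

    def time_index(self, n: int) -> jnp.ndarray:
        """Time index of each of the first n elements of this variable's output."""
        if self.timeseries_start is None:
            raise ValueError("timeseries_start is not set")
        return self.timeseries_start + self.timeseries_unit * jnp.arange(n)


weekly_effect = TimeAwareRV(timeseries_start=0, timeseries_unit=7)
weekly_effect.time_index(3)  # elements start at time steps 0, 7, 14
```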
This will lead to two PRs:

- `RandomVariable`, adding the corresponding values of `_start` and `_unit` when needed (e.g., the weekly effects should have `timeseries_unit = 7`).

These two will be assigned to me and hopefully resolved during this sprint. Nonetheless, I believe the padding may stretch into the next sprint. Do you agree, @dylanhmorris @damonbayer @AFg6K7h4fhy2?
I'm closing this for the moment. We can continue discussing details in the two new issues (which I'll link).
As models become more complex, we need a way to keep track of time units. Whatever solution we end up implementing, it should address the following issues (from internal notes):
- `[0, 1, 2] -> [2024-01-01, 2024-01-02, 2024-01-03]`
- `[1, 2, 3] -> [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3]`
- `[1, 2, 3] -> [NA, NA, 1, 2, 3]`
- `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] -> [28, 77]`
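One possible reading of the four examples is (1) mapping day indices to calendar dates, (2) expanding weekly values to daily, (3) padding a series so it starts later, and (4) aggregating daily values into weekly totals (1 + ... + 7 = 28, 8 + ... + 14 = 77). A minimal sketch under that reading follows; the function names are hypothetical, the start date 2024-01-01 is taken from the first example, and `NA` is represented as `NaN` since JAX arrays have no missing-value type:

```python
import datetime

import jax.numpy as jnp


def to_dates(index, start=datetime.date(2024, 1, 1)):
    """Map integer day offsets to calendar dates."""
    return [start + datetime.timedelta(days=int(i)) for i in index]


def expand_to_daily(x, unit=7):
    """Repeat each value `unit` times, e.g. weekly -> daily."""
    return jnp.repeat(jnp.asarray(x), unit)


def pad_front(x, n):
    """Left-pad with n NaNs so the series starts n steps later (casts to float)."""
    return jnp.pad(jnp.asarray(x, dtype=jnp.float32), (n, 0), constant_values=jnp.nan)


def aggregate(x, unit=7):
    """Sum consecutive blocks of `unit` values, e.g. daily -> weekly totals.
    Assumes len(x) is a multiple of `unit`."""
    return jnp.asarray(x).reshape(-1, unit).sum(axis=1)


to_dates([0, 1, 2])               # 2024-01-01, 2024-01-02, 2024-01-03
expand_to_daily([1, 2, 3])        # each value repeated 7 times
pad_front([1, 2, 3], 2)           # [nan, nan, 1., 2., 3.]
aggregate(jnp.arange(1, 15))      # [28, 77]
```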
This is closely related to #40.
Update June 26
This issue led to the following two issues: