mmcdermott / nested_ragged_tensors

Utilities for efficiently working with, saving, and loading collections of connected nested ragged tensors in PyTorch
MIT License

`load` and `load_slice` should have options to flatten the data in the last level without padding during loading. #8

Closed by mmcdermott 1 week ago

mmcdermott commented 2 months ago

This should happen during loading because the np.split call is the only thing that would change in that context, so it would be a very low-cost operation and would enable seamless switching between per-measurement and per-event tokenization. Tagging @Oufattole and @prenc for tracking

Oufattole commented 2 months ago

For load_slice, is J_FLAT what the flattened version of J should be? @mmcdermott

>>> J = JointNestedRaggedTensorDict({
...     "T":   [[1,           2,        3       ], [4,   5          ]],
...     "id":  [[[1, 2,   3], [3,   4], [1, 2  ]], [[3], [3,   2, 2]]],
...     "val": [[[1, 0.2, 0], [3.1, 0], [1, 2.2]], [[3], [3.3, 2, 0]]],
... })
>>> J_FLAT = JointNestedRaggedTensorDict({
...     "T":   [[1,           2,        3       ], [4,   5          ]],
...     "id":  [[1, 2,   3, 3,   4, 1, 2  ], [3, 3,   2, 2]],
...     "val": [[1, 0.2, 0, 3.1, 0, 1, 2.2], [3, 3.3, 2, 0]],
... })
mmcdermott commented 2 months ago

Hmm, this actually raises a good point @Oufattole. Because no, that isn't right -- it doesn't meet the criteria for an NRT, which is that within a given level, every key's list needs to have the same length as the corresponding list under every other key. E.g., J_FLAT["T"][0] is of length 3, so J_FLAT["id"][0] must also be of length 3, but it isn't.
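To make that invariant concrete, here is a minimal, hypothetical check (plain Python, not the library's API) that the J_FLAT proposed above fails:

>>> def check_level_lengths(tensors):
...     """Corresponding lists across keys at the same level must match in length."""
...     keys = list(tensors)
...     for i in range(len(tensors[keys[0]])):
...         lengths = {k: len(tensors[k][i]) for k in keys}
...         if len(set(lengths.values())) > 1:
...             raise ValueError(f"Length mismatch at index {i}: {lengths}")
>>> check_level_lengths({
...     "T":  [[1, 2, 3], [4, 5]],
...     "id": [[1, 2, 3, 3, 4, 1, 2], [3, 3, 2, 2]],
... })
Traceback (most recent call last):
    ...
ValueError: Length mismatch at index 0: {'T': 3, 'id': 7}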

I think what we'd need is:

>>> J = JointNestedRaggedTensorDict({
...     "T":   [[1,           2,        3       ], [4,   5          ]],
...     "id":  [[[1, 2,   3], [3,   4], [1, 2  ]], [[3], [3,   2, 2]]],
...     "val": [[[1, 0.2, 0], [3.1, 0], [1, 2.2]], [[3], [3.3, 2, 0]]],
... })
>>> J_FLAT = JointNestedRaggedTensorDict({
...     "T":   [[1,  1,  1, 2,   2,  3, 3  ], [4,   5, 5, 5     ]],
...     "id":  [[1, 2,   3, 3,   4, 1, 2  ], [3, 3,   2, 2]],
...     "val": [[1, 0.2, 0, 3.1, 0, 1, 2.2], [3, 3.3, 2, 0]],
... })

Or we'd need to drop T first in order to use this mode, which might be what I'd advocate for. Or maybe, instead of replicating the values of T, add zeros.
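For what it's worth, the replication option is a one-liner in numpy given the inner-list lengths (a sketch, assuming those lengths are already in hand from the bounds arrays):

>>> import numpy as np
>>> T0 = np.array([1, 2, 3])          # J["T"][0]
>>> inner_lens = np.array([3, 2, 2])  # lengths of the inner lists of J["id"][0]
>>> np.repeat(T0, inner_lens)
array([1, 1, 1, 2, 2, 3, 3])

This matches J_FLAT["T"][0] above; the zeros variant would zero out every repeat after the first occurrence.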

Oufattole commented 2 months ago

Although dropping T is simpler, that would correspond to dropping all time_deltas for event streams. The other option, duplicating, is also problematic, as we would then probably also want to track a mask of some sort to mask out the repeated time_deltas. Would the following approach be appropriate (see the sketch after this list)?

  1. Add a flatten_last_dim parameter to the __init__, load, and load_slice methods.
  2. Modify the _initialize_tensors method to handle flattened data when flatten_last_dim is True, allowing for different lengths across keys in the JointNestedRaggedTensorDict when loading with flatten_last_dim=True.
  3. Update the to_dense method to handle flattened data correctly.
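Here is a rough caller-side sketch of the above proposal. The flatten_last_dim name comes from point 1, but the file name and call signatures are illustrative, not an existing API:

>>> # Hypothetical usage; flatten_last_dim does not exist yet.
>>> J = JointNestedRaggedTensorDict.load("tensors.nrt", flatten_last_dim=True)
>>> batch = JointNestedRaggedTensorDict.load_slice("tensors.nrt", slice(0, 2), flatten_last_dim=True)
>>> dense = batch.to_dense()  # per point 3, must handle the flattened layout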
mmcdermott commented 2 months ago

Dropping all things not at the max dim is the right approach for now (either within NRT, or by having NRT raise an error and having them be dropped prior to saving).
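A minimal sketch of that pre-save dropping, assuming a plain dict of nested lists (the helper and its name are hypothetical, not part of NRT):

>>> def keys_at_max_depth(tensors):
...     """Keep only keys at the maximum nesting depth, dropping shallower ones like "T"."""
...     def depth(x):
...         return 1 + depth(x[0]) if isinstance(x, list) else 0
...     max_d = max(depth(v) for v in tensors.values())
...     return {k: v for k, v in tensors.items() if depth(v) == max_d}
>>> sorted(keys_at_max_depth({
...     "T":  [[1, 2, 3], [4, 5]],
...     "id": [[[1, 2, 3], [3, 4], [1, 2]], [[3], [3, 2, 2]]],
... }))
['id']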

For actually doing the flattening, the logic looks something like this:

>>> import numpy as np
>>> L = [[[1, 2,   3], [3,   4], [1, 2  ]], [[3], [3,   2, 2]]]
>>> flat_L = np.array([1, 2, 3, 3, 4, 1, 2, 3, 3, 2, 2])
>>> bounds_0 = [3, 5]            # cumulative counts of inner lists per outer list
>>> bounds_1 = [3, 5, 7, 8, 11]  # cumulative lengths of the innermost lists
>>> np.split(flat_L, bounds_1[:-1])
[array([1, 2, 3]), array([3, 4]), array([1, 2]), array([3]), array([3, 2, 2])]
>>> np.split(flat_L, [max(x) for x in np.split(bounds_1, bounds_0[:-1])][:-1])
[array([1, 2, 3, 3, 4, 1, 2]), array([3, 3, 2, 2])]
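As an aside, the max-over-splits comprehension can be vectorized with np.maximum.reduceat, which may matter for large bounds arrays (same merged bounds, different route):

>>> np.maximum.reduceat(np.array(bounds_1), np.r_[0, bounds_0[:-1]])
array([ 7, 11])
>>> np.split(flat_L, np.maximum.reduceat(np.array(bounds_1), np.r_[0, bounds_0[:-1]])[:-1])
[array([1, 2, 3, 3, 4, 1, 2]), array([3, 3, 2, 2])]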

Here is where you see these np.split(values, bounds[:-1]) calls in load_slice: https://github.com/mmcdermott/nested_ragged_tensors/blob/main/src/nested_ragged_tensors/ragged_numpy.py#L1032C2-L1032C61