mmcdermott / nested_ragged_tensors

Utilities for efficiently working with, saving, and loading collections of connected nested ragged tensors in PyTorch
MIT License

You should be able to easily extract from a tensor a contiguous slice of nested values, even when that slice does not correspond to a meaningful dense slice. #34

Open · Oufattole opened this issue 2 weeks ago

Oufattole commented 2 weeks ago

The issue is that currently, when loading data for a specific patient, all data for that patient is loaded instead of just the desired subset.

Current Behavior: When loading data for a specific patient (in this case, patient 0), the entire dataset for that patient is loaded from disk, even if only a subset of the data is needed.

Desired Behavior: The ability to load only a specific slice of data for a patient, reducing the amount of data read from disk and potentially improving performance.

For example (play around with this colab to interact with the JNRT):

  1. You have a JNRT structure ehr_data containing patient data with various fields like timestamp, code, numeric_value, and doctor_note.
  2. You want to load only the codes for indexes 1:3 for patient 0.
  3. Currently, you have to load all the data for patient 0 and then slice it in memory:
import numpy as np
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

fp = ...  # filepath to JNRT
st, end = 1, 3
# load_slice reads *all* of patient 0's data; the desired subset is then cut out in memory
loaded_patient = JointNestedRaggedTensorDict.load_slice(fp, 0)
np.concatenate(loaded_patient.tensors["dim1/code"][st:end], axis=0)

The issue is that JointNestedRaggedTensorDict.load_slice(fp, 0) loads all data for patient 0, not just the desired slice.

To resolve this issue, you'd want a method that allows you to specify both the patient index and the desired slice of data to load. This could potentially be implemented as a new method or an extension of the existing load_slice method, for example:

code_dim = 3  # the measurement-level dimension to slice along
loaded_slice = JointNestedRaggedTensorDict.load_partial_slice(
    fp, sample_index=0, dim=code_dim, index=slice(1, 3)  # hypothetical API
)

This hypothetical load_partial_slice method would only load the specified slice of data for the given patient.

mmcdermott commented 1 week ago

So, I think the problem here is as follows.

Basically, suppose I have 3D data in a JNRT of patients, who have events, which have measurements. The current slicing interface allows a subset of slice options, all of which have the property that either (a) the slice raises an error, or (b) the slice returns a JNRT which, when densified, matches what you would get by applying the same slice to the densified version of the original JNRT (up to padding).

These capabilities include things like selecting a single patient, or selecting events 5 - 10 for a patient.
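Roughly, in code (a sketch only: this assumes __getitem__ slicing along the leading dimension and a to_dense() method as discussed in this thread, and fp is the hypothetical JNRT filepath from above; exact signatures may differ):

patient_0 = JointNestedRaggedTensorDict.load_slice(fp, 0)  # select a single patient
events_5_10 = patient_0[5:10]                               # select events 5-10 for that patient

# Property (b): densifying this slice matches slicing the densified original,
# up to padding widths.
dense_slice = events_5_10.to_dense()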

However, you might want to retrieve just measurements 5-10 for a patient, regardless of which events those measurements occur within. This kind of slice does not have a densified analog.

You can currently do this by flattening the tensor at the last dimension, then slicing at 5-10, but this may be unnecessarily expensive.
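In code, that workaround is roughly the following (a sketch; flatten's exact signature is an assumption here):

patient_0 = JointNestedRaggedTensorDict.load_slice(fp, 0)  # still loads all of patient 0's data
flat = patient_0.flatten()        # collapse the event dim into the measurement dim
measurements_5_10 = flat[5:10]    # now a plain contiguous slice along measurements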

@Oufattole, does that describe your issue, or am I misunderstanding?

Oufattole commented 1 week ago

I think the core issue is that densifying currently takes the structure of the entire JNRT and pads it. We have been assuming a 3D stored JNRT with dimensions (patients, events, measurements), and we have considered how to slice and densify this into 2D, measurement-level data: i.e. (patients, measurements), with the events index flattened out.

The current slicing paradigm, as you pointed out, conforms to what you'd get from slicing a densified version of all dimensions in the 3D data. However, the 2D-level slice we're after (measurements 5-10 across all events) doesn't have a direct densified analog over all dimensions, only across the subset, i.e. the (patients, measurements) dimensions.

To address this, I propose two options:

  1. Create separate JNRTs: We could have 3D JNRT files for models expecting 3D data (like event_stream) and 2D JNRT files for models expecting 2D data. This would require a small modification in meds-torch to properly extract the start and end times: basically, look for non-zero time deltas in the JNRT (assuming we sub-pad event times with 0s so they are the same length as the measurements); those indices correspond to the dates in the static_df. The dimensions are then [patient_id, measurement], and we can use the supported JNRT slicing and densify operations.
  2. Allow more advanced indexing: This would involve creating a new function in the repo that lets us select the start and end of the measurements dimension (3rd dim) and load just that data. We would ignore the event dimension and densify the data. This approach does not have a 3D densified analog, since it ignores dimension 2 (events), and only has a 2D densified analog.

Both of these approaches avoid loading all patient data and will load the minimal subset of data from disk, improving file IO efficiency (hopefully). I think it makes the most sense to go with option 1, for simplicity in the JNRT repo, right? It may be a misuse of this data structure, and may further confuse users, if we add non-straightforward notions of slicing that ignore a dimension. Additionally, I think the meds-torch modification for loading the correct st and end indexes is not too complex and will not significantly increase the complexity of the pytorch_dataset class.
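For reference, a minimal numpy sketch of the non-zero time delta trick from option 1 (the field name time_delta and the values are illustrative; this assumes event times are sub-padded with 0s to measurement length):

import numpy as np

# One patient's measurement-level time deltas after sub-padding; a non-zero entry
# marks the first measurement of a new event.
time_delta = np.array([3.0, 0.0, 0.0, 1.5, 0.0, 2.0, 0.0, 0.0])

event_start_idxs = np.nonzero(time_delta)[0]  # -> array([0, 3, 5])
# These measurement indices line up with the event dates in static_df, so the
# desired st/end measurement bounds can be looked up from event-level times.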

mmcdermott commented 1 week ago

So I think much of the functionality you're talking about in 2 already exists, through flatten and the current slicing infrastructure. In particular, in order to do #2, you definitely need to load the bounds of the event-level dimension, so that you know how many measurements in total a patient has, and you also need to modify data stored at the event level, like timestamps, so that it matches up with the new 2nd dimension. The flatten function already does all of this, and it can be optimized to avoid loading anything unnecessary from disk in the process.
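As a purely conceptual illustration of that event-level bookkeeping (plain numpy, not the library's internals):

import numpy as np

event_timestamps = np.array([10.0, 11.0, 12.0])  # one value per event
measurements_per_event = np.array([2, 1, 3])      # ragged lengths of the last dim

# After flattening the event dimension, each event-level value is repeated once per
# measurement so it lines up with the new measurement-level 2nd dimension.
flat_timestamps = np.repeat(event_timestamps, measurements_per_event)
# -> array([10., 10., 11., 12., 12., 12.])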

That being said, just caching a 2D form from the start is likely a good idea.

mmcdermott commented 1 week ago

Given that the branch that adds flatten has been merged, @Oufattole can you comment at some point (not urgent) on whether or not that is sufficiently efficient for your use case?