Current Status:
MultiWindow and RandomWindows dataset classes manage window data as simple dictionaries
The base PytorchDataset and RandomWindows dataset classes also return simple dictionaries
Data management and window fusion are done with raw tensors in each of these separate classes
Static data handling during fusion is inefficient; the sketch below illustrates the duplication
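For illustration only, a rough sketch of the current dict-per-window pattern, with hypothetical names (build_window_dicts, the key names) chosen to mirror the WindowBatch fields below rather than the actual dataset code:

```python
import torch


def build_window_dicts(
    codes: torch.Tensor,
    values: torch.Tensor,
    static_indices: torch.Tensor,
    static_values: torch.Tensor,
    window_bounds: list[tuple[int, int]],
) -> list[dict[str, torch.Tensor]]:
    """Hypothetical helper sketching the current dict-per-window pattern."""
    windows = []
    for start, end in window_bounds:
        windows.append({
            "code": codes[start:end],
            "numeric_value": values[start:end],
            # the same subject-level static tensors get copied into every window dict,
            # which is the duplication called out above
            "static_indices": static_indices.clone(),
            "static_values": static_values.clone(),
        })
    return windows
```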
Proposed Changes:
ESGPT uses this kind of PyTorch batch class. This approach is nice because it lets us pack functionality for manipulating windows of data, and for converting data into other formats (such as a MEDS-like format), into class methods. Additionally, we can add optional keys that get filled in by models, which would clean up the manual use of sentinel keys elsewhere in the dataset and enable type checking for those keys. We could even require models to run validation functions before returning their batch dataclass, ensuring that prior stages added the appropriate sentinel keys.
We need to add a base Window class, used by the PyTorch dataset class, that simply declares the keys the dataset is expected to output:
```python
import dataclasses

import torch


@dataclasses.dataclass
class WindowBatch:
    """A dataclass representing a batch of temporal data for a PyTorch model.

    Attributes:
        code: A long tensor of shape (batch_size, sequence_length) containing categorical codes/indices
        mask: A boolean tensor of shape (batch_size, sequence_length) indicating which elements are valid (not padding)
        numeric_value: A float tensor of shape (batch_size, sequence_length) containing numeric values
        numeric_value_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which numeric values were observed
        prepended_static_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which elements are prepended static data
        time_delta_days: A float tensor of shape (batch_size, sequence_length) indicating days between events
        static_indices: A long tensor of shape (batch_size, n_static) containing static categorical codes
        static_values: A float tensor of shape (batch_size, n_static) containing static numeric values
        static_mask: A boolean tensor of shape (batch_size, n_static) indicating which static elements are valid
        static_numeric_value_mask: A boolean tensor of shape (batch_size, n_static) indicating which static numeric values are valid
        label: A float tensor of shape (batch_size,) containing classification labels
    """

    code: torch.LongTensor
    mask: torch.BoolTensor
    numeric_value: torch.FloatTensor
    numeric_value_mask: torch.BoolTensor
    prepended_static_mask: torch.BoolTensor  # Added this field
    time_delta_days: torch.FloatTensor
    static_indices: torch.LongTensor | None = None
    static_values: torch.FloatTensor | None = None
    static_mask: torch.BoolTensor | None = None
    static_numeric_value_mask: torch.BoolTensor | None = None
    label: torch.FloatTensor | None = None
```
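As a hedged illustration of the optional-key/validation idea above (not an existing API; validate_batch_fields and its signature are placeholder names), a model could check that the optional fields it relies on were filled in by earlier stages before returning its batch:

```python
# Hypothetical sketch: a helper a model could call before returning its batch,
# verifying that the optional fields it depends on were actually populated.
def validate_batch_fields(batch: WindowBatch, required: tuple[str, ...]) -> WindowBatch:
    missing = [name for name in required if getattr(batch, name) is None]
    if missing:
        raise ValueError(f"WindowBatch is missing required optional fields: {missing}")
    return batch


# e.g. a supervised model could require labels to be present:
# batch = validate_batch_fields(batch, required=("label",))
```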
Create WindowBatch Class for the MultiWindow and RandomWindow datasets:
Implement a tokenization-method-specific to_meds function. I added one to the eic_forecasting model and will add one to the triplet model soon: #107. We should centralize these in the Window classes, I think.
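A minimal sketch of what a centralized conversion could look like, assuming a MEDS-like long format; the function name, columns, and logic are placeholders and not the existing eic_forecasting implementation:

```python
# Hypothetical sketch of a centralized to_meds-style conversion over a WindowBatch.
# In practice this would live as a method on the Window classes, and absolute times
# and real subject ids would come from the dataset rather than indices.
def window_batch_to_meds_rows(batch: WindowBatch) -> list[dict]:
    rows = []
    batch_size, seq_len = batch.code.shape
    for i in range(batch_size):
        for j in range(seq_len):
            if not bool(batch.mask[i, j]):
                continue  # skip padding positions
            rows.append({
                "subject_id": i,  # placeholder subject identifier
                "code": int(batch.code[i, j]),
                "numeric_value": float(batch.numeric_value[i, j])
                if bool(batch.numeric_value_mask[i, j])
                else None,
                "time_delta_days": float(batch.time_delta_days[i, j]),
            })
    return rows
```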
This will give us:
Type safety through dataclass structures
Memory-efficient early window fusion by avoiding static data duplication
Consistent tensor management across the codebase
Cleaner API for window operations
Implementation Steps:
[ ] Create Window class
[ ] Create MultiWindow class (with window fusion logic; a rough sketch follows this list)
[ ] Update the dataset classes and models to use these classes instead of dictionary batches
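A rough, hedged sketch of the fusion idea, assuming the MultiWindow class holds one WindowBatch per named window and shares the subject-level static tensors across windows instead of duplicating them; MultiWindowBatch and fuse are placeholder names, not an agreed design:

```python
# Hypothetical sketch only: a multi-window container that stores static data once
# and reuses it for every window, avoiding the per-window duplication described above.
@dataclasses.dataclass
class MultiWindowBatch:
    windows: dict[str, WindowBatch]  # e.g. {"pre": ..., "post": ...}
    static_indices: torch.LongTensor | None = None
    static_values: torch.FloatTensor | None = None

    def fuse(self, name: str) -> WindowBatch:
        """Return the named window with the shared static tensors attached."""
        window = self.windows[name]
        return dataclasses.replace(
            window,
            static_indices=self.static_indices,
            static_values=self.static_values,
        )
```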
Does this look right, @mmcdermott?