Oufattole / meds-torch

Refactor PytorchDataset class and other dataset classes to a `Window` dataclass for batches instead of a dictionary #126

Open Oufattole opened 2 weeks ago

Oufattole commented 2 weeks ago

Current Status:

Proposed Changes: ESGPT uses this kind of PyTorch batch class. The approach is attractive because it lets us package functionality for manipulating windows of data, and for converting data into other formats (such as a MEDS-like format), as class methods. We can also add optional fields that models fill in, which would replace the manual sentinel keys used elsewhere in the dataset code and enable type checking for those fields. We could even require models to run a validation function before returning their batch dataclass, ensuring that prior stages added the appropriate sentinel fields.
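As a minimal sketch of the validation idea, assuming a hypothetical `validate` method and a hypothetical model-filled `logits` field (neither is part of meds-torch), a batch dataclass could check that earlier stages populated the optional fields a later stage needs. `Any` annotations stand in for tensor types so the snippet has no dependencies:

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ExampleBatch:
    """Hypothetical batch; `logits` and `validate` are illustrative names only."""

    code: Any
    mask: Any
    # Optional field that a model stage would fill in later (hypothetical).
    logits: Optional[Any] = None

    def validate(self, required: list[str]) -> None:
        """Raise if any name in `required` is not a field or is still None."""
        known = {f.name for f in dataclasses.fields(self)}
        unknown = [name for name in required if name not in known]
        if unknown:
            raise KeyError(f"Unknown batch fields: {unknown}")
        missing = [name for name in required if getattr(self, name) is None]
        if missing:
            raise ValueError(f"Batch is missing required fields: {missing}")


batch = ExampleBatch(code=[[1, 2, 3]], mask=[[True, True, False]])
batch.validate(["code", "mask"])  # passes: both fields are populated

try:
    batch.validate(["logits"])  # fails: no model stage has filled `logits`
except ValueError as err:
    validation_error = str(err)
```

Because the check goes through `dataclasses.fields`, a misspelled field name fails loudly instead of silently passing, which is the main gain over string keys in a dictionary.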

  1. Add a base Window class for the PyTorch dataset class to use. It holds only the keys the PyTorch dataset class is expected to output:

    import dataclasses

    import torch


    @dataclasses.dataclass
    class WindowBatch:
        """A dataclass representing a batch of temporal data for a PyTorch model.

        Attributes:
            code: A long tensor of shape (batch_size, sequence_length) containing categorical codes/indices
            mask: A boolean tensor of shape (batch_size, sequence_length) indicating which elements are valid (not padding)
            numeric_value: A float tensor of shape (batch_size, sequence_length) containing numeric values
            numeric_value_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which numeric values were observed
            prepended_static_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which elements are prepended static data
            time_delta_days: A float tensor of shape (batch_size, sequence_length) indicating days between events
            static_indices: A long tensor of shape (batch_size, n_static) containing static categorical codes
            static_values: A float tensor of shape (batch_size, n_static) containing static numeric values
            static_mask: A boolean tensor of shape (batch_size, n_static) indicating which static elements are valid
            static_numeric_value_mask: A boolean tensor of shape (batch_size, n_static) indicating which static numeric values are valid
            label: A float tensor of shape (batch_size,) containing classification labels
        """

        # Sequence-level fields, always present.
        code: torch.LongTensor
        mask: torch.BoolTensor
        numeric_value: torch.FloatTensor
        numeric_value_mask: torch.BoolTensor
        prepended_static_mask: torch.BoolTensor
        time_delta_days: torch.FloatTensor
        # Static fields and labels are optional and default to None.
        static_indices: torch.LongTensor | None = None
        static_values: torch.FloatTensor | None = None
        static_mask: torch.BoolTensor | None = None
        static_numeric_value_mask: torch.BoolTensor | None = None
        label: torch.FloatTensor | None = None

  2. Create a `MultiWindowBatch` class for the MultiWindow and RandomWindow datasets:

    @dataclasses.dataclass
    class MultiWindowBatch:
        # Maps each window name to its batch.
        windows: dict[str, WindowBatch]
        primary_window: str | None = None
        early_fused_window: WindowBatch | None = None

        def fuse_windows(self, windows_to_fuse: list[str], fused_name: str):
            """Efficiently fuse windows, handling static data properly."""
            ...

  3. Implement a `to_meds` function specific to each tokenization method. I added one to the eic_forecasting model and will add one to the triplet model soon: #107. I think we should centralize these in the Window classes.
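To make the fusion in step 2 concrete, here is a rough sketch of one way `fuse_windows` might behave. It is not the proposed implementation: `ToyWindow` and `ToyMultiWindow` are hypothetical stand-ins with one sequence field and one static field, nested lists replace tensors, and padding and masks are ignored. It also assumes that all windows for a subject share the same static data, so a single copy is kept instead of one per window:

```python
import dataclasses
from typing import Optional

# Nested lists of shape (batch_size, length) stand in for tensors.
Matrix = list[list[int]]


@dataclasses.dataclass
class ToyWindow:
    """Simplified, hypothetical stand-in for WindowBatch."""

    code: Matrix
    static_indices: Optional[Matrix] = None


@dataclasses.dataclass
class ToyMultiWindow:
    windows: dict[str, ToyWindow]

    def fuse_windows(self, windows_to_fuse: list[str], fused_name: str) -> None:
        """Concatenate sequence data per subject; keep one copy of static data."""
        parts = [self.windows[name] for name in windows_to_fuse]
        batch_size = len(parts[0].code)
        # Join each subject's codes across windows, in the order requested.
        fused_code = [
            [c for part in parts for c in part.code[i]] for i in range(batch_size)
        ]
        # Assumption: static data is identical across windows, so reuse the
        # first non-None copy rather than concatenating every copy.
        static = next(
            (p.static_indices for p in parts if p.static_indices is not None), None
        )
        self.windows[fused_name] = ToyWindow(code=fused_code, static_indices=static)


batch = ToyMultiWindow(
    windows={
        "pre": ToyWindow(code=[[1, 2], [3, 4]], static_indices=[[9], [8]]),
        "post": ToyWindow(code=[[5], [6]], static_indices=[[9], [8]]),
    }
)
batch.fuse_windows(["pre", "post"], fused_name="fused")
fused = batch.windows["fused"]
```

With real padded tensors, the fused sequence would also need its padding moved to the end and every mask field updated to match, which this sketch leaves out.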

This will give us:

  1. Type safety through dataclass structures
  2. Memory-efficient early window fusion by avoiding static data duplication
  3. Consistent tensor management across the codebase
  4. Cleaner API for window operations

Implementation Steps:

Does this look right, @mmcdermott?