Oufattole / meds-torch

Add API support for generation that allows generating data after a prediction time and caching both the predictions and the real data. #104

Closed Oufattole closed 2 weeks ago

Oufattole commented 1 month ago

Currently, we only log the autoregressive loss, but should add some simple metrics:

Additionally, we need to support both analysis workflows on the generated data. We should have a separate generation script that iterates through the data, conditions on each subject's past data up to a prediction_time (so the script expects a task parquet with prediction_times to be provided), and creates a parquet file with the following schema:

import pyarrow as pa

trajectory_data_type = pa.struct([
    ('code', pa.int64()),
    ('mask', pa.bool_()),
    ('numeric_value', pa.float64()),
    ('numeric_value_mask', pa.bool_()),
    ('time_delta_days', pa.float64()),
])

generation_analysis = pa.schema(
    [
        ("subject_id", pa.int64()),
        ("prediction_time", pa.timestamp("us")),
        ("actual", trajectory_data_type),
        ("generation_1", trajectory_data_type),
        ...
        ("generation_N", trajectory_data_type),
    ]
)

The generation script should

  1. Use trainer.predict, which returns all the predictions
  2. Store these in a parquet following the generation_analysis schema

ESGPT has a script that does this.
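As a rough illustration of step 2, the sketch below merges per-batch prediction outputs into one row per (subject_id, prediction_time). Lightning's trainer.predict returns a list with one entry per batch, but the layout of each entry here (a dict of per-sample lists keyed by subject_id, prediction_time, actual, and generation_k) is an assumption made for the example, as is the helper name collect_rows. It is not the actual meds-torch or ESGPT output format.

```python
def collect_rows(batch_outputs):
    """Merge per-batch prediction dicts into one row per (subject_id, prediction_time).

    Assumes (hypothetically) that each batch output is a dict of per-sample lists.
    """
    rows = {}
    for out in batch_outputs:
        pairs = zip(out["subject_id"], out["prediction_time"])
        for i, (subject_id, prediction_time) in enumerate(pairs):
            row = rows.setdefault(
                (subject_id, prediction_time),
                {"subject_id": subject_id, "prediction_time": prediction_time},
            )
            # Copy only the trajectory columns named in the schema.
            for key, values in out.items():
                if key == "actual" or key.startswith("generation_"):
                    row[key] = values[i]
    return list(rows.values())


# Stand-in for the list returned by trainer.predict (two batches, three samples).
batch_outputs = [
    {
        "subject_id": [1, 2],
        "prediction_time": ["2020-01-01", "2020-02-01"],
        "actual": [{"code": [5]}, {"code": [7]}],
        "generation_1": [{"code": [5]}, {"code": [8]}],
    },
    {
        "subject_id": [3],
        "prediction_time": ["2020-03-01"],
        "actual": [{"code": [9]}],
        "generation_1": [{"code": [9]}],
    },
]
rows = collect_rows(batch_outputs)
```

The resulting list of row dicts is the kind of structure that could then be converted to a table and written to parquet under the generation_analysis schema.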

Finally, we need to add a multiwindow dataset class that provides a pre window (data prior to the task prediction time) and a post window (data after it), as this will be used for unconstrained generative evaluation (see #106). Windows are currently keys in the batch dictionary. I need to add a default window name attribute that flattens the default window to the top level of the batch dictionary, so current models (which expect top-level batch keys) are still supported.
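A minimal sketch of the flattening idea, using plain lists in place of tensors. The function name flatten_default_window, the window names "pre" and "post", and the choice to keep the original window keys alongside the flattened ones are all assumptions for illustration, not the implemented meds-torch behavior.

```python
def flatten_default_window(batch, default_window_name):
    """Return a copy of `batch` with the default window's entries at the top level.

    The per-window keys are kept, so multiwindow-aware code can still use them.
    """
    if default_window_name not in batch:
        raise KeyError(f"window {default_window_name!r} not found in batch")
    flat = dict(batch)
    for key, value in batch[default_window_name].items():
        if key in flat:
            raise ValueError(f"top-level key {key!r} collides with a window name")
        flat[key] = value
    return flat


# A toy multiwindow batch: each window holds its own model inputs.
batch = {
    "pre": {"code": [1, 2], "mask": [True, True]},
    "post": {"code": [3], "mask": [True]},
}
flat = flatten_default_window(batch, "pre")
```

With "pre" as the default window, a model that reads flat["code"] sees the pre-prediction-time data, while evaluation code can still reach flat["post"].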

TODOS

Oufattole commented 3 weeks ago

The best practice would be to validate the schema as well, as meds_transform does when it checks that the extracted data follows the meds-schema. Here is a link to where meds_transform does this. Here is where the meds schema is defined.

We should add schema code like this in src/meds_torch/schemas/generate_analysis_schema.py:

import pyarrow as pa

ACTUAL_FUTURE = "ACTUAL_FUTURE"
GENERATE_PREFIX = "GENERATE//"
INPUT_DATA = "INPUT_DATA"

trajectory_data_type = pa.struct([
    ('code', pa.int64()),
    ('mask', pa.bool_()),
    ('numeric_value', pa.float64()),
    ('numeric_value_mask', pa.bool_()),
    ('time_delta_days', pa.float64()),
])

# do_include_actual needs a default because it follows a defaulted parameter.
def generation_analysis_schema(num_samples=0, do_include_actual=False):
    # The conditional is parenthesized so it only toggles the ACTUAL_FUTURE field;
    # unparenthesized, the whole field list would become [] when the flag is False.
    return pa.schema(
        [
            ("subject_id", pa.int64()),
            ("prediction_time", pa.timestamp("us")),
            (INPUT_DATA, trajectory_data_type),
        ]
        + [(GENERATE_PREFIX + str(i), trajectory_data_type) for i in range(num_samples)]
        + ([(ACTUAL_FUTURE, trajectory_data_type)] if do_include_actual else [])
    )

And at the end of the generate_trajectory script we should validate the generated dataframe as follows:

validated_schema = generation_analysis_schema(
    num_samples=cfg.model.num_samples,
    do_include_actual=(ACTUAL_FUTURE in df.columns),
)
return df.collect().to_arrow().cast(validated_schema)

I'm also quite sure I need to modify trajectory_data_type, since all of its fields will actually be lists.