**Closed** — Oufattole closed this issue 2 weeks ago
The best practice would be to validate the schema as well, as in meds_transform, which validates that the extracted data follows the meds-schema. Here is a link to where meds_transform does this, and here is where the meds schema is defined.
We should add schema code like this in `src/meds_torch/schemas/generate_analysis_schema.py`:

```python
import pyarrow as pa

ACTUAL_FUTURE = "ACTUAL_FUTURE"
GENERATE_PREFIX = "GENERATE//"
INPUT_DATA = "INPUT_DATA"

trajectory_data_type = pa.struct(
    [
        ("code", pa.int64()),
        ("mask", pa.bool_()),
        ("numeric_value", pa.float64()),
        ("numeric_value_mask", pa.bool_()),
        ("time_delta_days", pa.float64()),
    ]
)


def generation_analysis_schema(num_samples: int = 0, do_include_actual: bool = False):
    return pa.schema(
        [
            ("subject_id", pa.int64()),
            ("prediction_time", pa.timestamp("us")),
            (INPUT_DATA, trajectory_data_type),
        ]
        + [(GENERATE_PREFIX + str(i), trajectory_data_type) for i in range(num_samples)]
        # Parenthesize the conditional so the `else []` branch only drops the
        # ACTUAL_FUTURE field, not the whole field list (operator precedence).
        + ([(ACTUAL_FUTURE, trajectory_data_type)] if do_include_actual else [])
    )
```
And at the end of the generate_trajectory script we should validate the generated dataframe as follows:

```python
validated_schema = generation_analysis_schema(
    num_samples=cfg.model.num_samples,
    do_include_actual=(ACTUAL_FUTURE in df.columns),
)
return df.collect().to_arrow().cast(validated_schema)
```
I'm also quite sure I need to modify trajectory_data_type
as all the fields will actually be lists.
Currently, we only log the autoregressive loss, but we should add some simple metrics.
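For example (my assumption, since the metrics are not enumerated here), perplexity is essentially free to derive from the mean autoregressive loss we already log:

```python
import math

def perplexity(mean_autoregressive_loss: float) -> float:
    """Perplexity is exp of the mean per-token cross-entropy loss."""
    return math.exp(mean_autoregressive_loss)

print(perplexity(0.0))  # → 1.0
```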
Additionally, we need to support both analysis workflows on the generated data. We should have a separate generation script that iterates through the data, conditioned on past data up to a prediction_time (so this script expects a task parquet to be provided with prediction_times), and creates a parquet file with the generation_analysis schema above. ESGPT has a script that performs this.
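The generation loop could be sketched roughly as below. The row format of the task parquet, the `generate_fn` callable, and `run_generation` itself are all hypothetical names for illustration, not the actual meds-torch API:

```python
from datetime import datetime

def run_generation(task_rows, num_samples, generate_fn):
    """For each (subject_id, prediction_time) task row, condition on
    history up to prediction_time and emit one output row per task row,
    with one GENERATE//{i} column per sample."""
    out = []
    for row in task_rows:
        samples = [generate_fn(row["subject_id"], row["prediction_time"])
                   for i in range(num_samples)]
        record = {"subject_id": row["subject_id"],
                  "prediction_time": row["prediction_time"]}
        record.update({f"GENERATE//{i}": s for i, s in enumerate(samples)})
        out.append(record)
    return out

# Toy usage with a stub generator standing in for the model.
tasks = [{"subject_id": 1, "prediction_time": datetime(2020, 1, 1)}]
rows = run_generation(tasks, num_samples=2, generate_fn=lambda s, t: {"code": [0]})
print(sorted(rows[0].keys()))
# → ['GENERATE//0', 'GENERATE//1', 'prediction_time', 'subject_id']
```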
Finally, we need to add a multiwindow dataset class with a pre window (data prior to the task prediction time) and a post window (data after it), as this will be used for unconstrained generative evaluation (see #106). Windows are currently keys in the batch dictionary. I need to add a default window name attribute, which will flatten the default window to the top level of the batch dictionary, so current models (which expect top-level batch keys) are still supported.
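A minimal sketch of that flattening behavior (the function name and batch layout are assumptions for illustration):

```python
def flatten_default_window(batch, default_window=None):
    """If a default window name is set, lift that window's sub-dict to the
    top level of the batch so models expecting top-level keys still work."""
    if default_window is None:
        return batch
    flat = {k: v for k, v in batch.items() if k != default_window}
    flat.update(batch[default_window])  # lift default window to top level
    return flat

# Toy batch with two windows; "pre" is treated as the default window.
batch = {"pre": {"code": [1, 2]}, "post": {"code": [3]}}
flat = flatten_default_window(batch, "pre")
print(flat["code"], "post" in flat)  # → [1, 2] True
```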
TODOS