Oufattole / meds-torch

MIT License

Add a predict script that generates embeddings #109

Closed by Oufattole 2 weeks ago

Oufattole commented 4 weeks ago

Add a general predict script for all models that can store embeddings, logits, loss, and predictions in a meds-evaluation-compliant schema.
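One way the script could derive the prediction fields from model outputs is sketched below. It assumes a binary task with a single logit per prediction time and a 0.5 decision threshold. Neither is fixed by this issue, and `logits_to_predictions` is a hypothetical helper, not an existing meds-torch function.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow for large |x|)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_predictions(logits, threshold=0.5):
    """Hypothetical helper: map per-sample binary logits to meds-eval prediction fields."""
    rows = []
    for logit in logits:
        prob = stable_sigmoid(logit)
        rows.append(
            {
                "predicted_boolean_probability": prob,
                # Hard label from the (assumed) probability threshold.
                "predicted_boolean_value": prob >= threshold,
            }
        )
    return rows


print(logits_to_predictions([0.0, 2.0, -1000.0]))
```

Storing the raw logits next to the probability means the threshold can be changed later without rerunning the model.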

This is the meds-eval schema:

```python
import pyarrow as pa

predicted_labels = pa.schema(
    [
        ("subject_id", pa.int64()),
        ("prediction_time", pa.timestamp("us")),
        ("boolean_value", pa.bool_()),
        ("predicted_boolean_value", pa.bool_()),
        ("predicted_boolean_probability", pa.float64()),
    ]
)
```

I will add optional logits, embeddings, and loss:

```python
import pyarrow as pa

predicted_labels = pa.schema(
    [
        # Required
        ("subject_id", pa.int64()),
        ("prediction_time", pa.timestamp("us")),
        # Optional, but all three must be present together to report predictions
        ("boolean_value", pa.bool_()),
        ("predicted_boolean_value", pa.bool_()),
        ("predicted_boolean_probability", pa.float64()),
        # Optional model outputs
        ("embeddings", pa.list_(pa.float64())),
        ("logits_sequence", pa.list_(pa.list_(pa.float64()))),
        ("logits", pa.list_(pa.float64())),
        ("loss", pa.float64()),
    ]
)
```

Output files should also be validated against the schema, as described here.

TODOS