Oufattole / meds-torch

MIT License

Fairness Evaluation #48

Open Oufattole opened 5 days ago

Oufattole commented 5 days ago

Currently, to train and evaluate a model, you run the following script:

set -e  # Exit immediately if a command exits with a non-zero status.
source $(conda info --base)/etc/profile.d/conda.sh
conda activate fairness
ROOT_DIR=??? # Add a Root directory with your meds folder

# Train a model
meds-torch-train \
        experiment=$experiment paths.data_dir=${TENSOR_DIR} \
        paths.meds_cohort_dir=${MEDS_DIR} paths.output_dir=${TRAIN_DIR} \
        data.task_name=$task_name data.task_root_dir=$TASKS_DIR

# Evaluate a model
meds-torch-eval \
        experiment=$experiment paths.data_dir=${TENSOR_DIR} \
        paths.meds_cohort_dir=${MEDS_DIR} paths.output_dir=${TRAIN_DIR} \
        data.task_name=$task_name data.task_root_dir=$TASKS_DIR

The results are written to $(meds-torch-latest-dir path=${TRAIN_DIR})/results_summary.parquet and include test/auc, among other metrics, computed on the full dataset. It would be useful to have a fairness evaluation in which a user can define a subgroup of patients and have these metrics stored for that subgroup.

Ideally, a user can add the kwarg eval_groups=CODE//NAME, and we evaluate metrics (like AUC) for that group specifically rather than for everyone in the dataset. This kwarg can be added by modifying eval.yaml. To start, we can assume CODE//NAME is a static feature, such as the code GENDER in MIMIC-IV, which would be stored in the static_df. So I think we need to add multiclass group labels to the batch in the pytorch_dataset class based on this.
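
A minimal sketch of the group-labeling step described above, in plain Python and independent of meds-torch. Everything here is an assumption for illustration: the function name build_group_labels, the dict-of-codes input (standing in for whatever static_df actually holds), the "PREFIX//VALUE" code format, and the -1 label for subjects without a matching code.

```python
from typing import Dict, List, Tuple

UNKNOWN_GROUP = -1  # hypothetical label for subjects with no matching static code


def build_group_labels(
    static_codes: Dict[int, List[str]], group_prefix: str
) -> Tuple[Dict[int, int], List[str]]:
    """Map each subject to an integer group id based on its static codes.

    static_codes: subject_id -> list of static codes (assumed "PREFIX//VALUE").
    group_prefix: the code prefix that defines the groups, e.g. "GENDER".
    """
    marker = group_prefix + "//"
    # Distinct codes under the prefix, sorted so group ids are stable across runs.
    vocab = sorted(
        {code for codes in static_codes.values() for code in codes if code.startswith(marker)}
    )
    code_to_id = {code: i for i, code in enumerate(vocab)}

    labels: Dict[int, int] = {}
    for subject_id, codes in static_codes.items():
        matches = [code for code in codes if code in code_to_id]
        # Use the first matching code; subjects without one get UNKNOWN_GROUP.
        labels[subject_id] = code_to_id[matches[0]] if matches else UNKNOWN_GROUP
    return labels, vocab


# Toy data, invented for this example only.
static_codes = {
    1: ["GENDER//F", "RACE//A"],
    2: ["GENDER//M"],
    3: ["RACE//B"],
}
labels, vocab = build_group_labels(static_codes, "GENDER")
print(labels)  # {1: 0, 2: 1, 3: -1}
print(vocab)   # ['GENDER//F', 'GENDER//M']
```

The integer label could then be attached to each sample in the batch, so test_step can split predictions by group.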

I think we just need to update the test_step function in the SupervisedModule like so:

    def test_step(self, batch, batch_idx):
        output: OutputBase = self.forward(batch)
        # update the test metrics with this batch; they are aggregated across the epoch
        self.test_acc.update(output.logits.squeeze(), batch[self.task_name].float())
        self.test_auc.update(output.logits.squeeze(), batch[self.task_name].float())
        self.test_apr.update(output.logits.squeeze(), batch[self.task_name].int())
        if self.cfg.eval_groups:
            ...  # compute and log fairness metrics

        self.log("test/loss", output.loss, batch_size=self.cfg.batch_size)
        return output.loss
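
One possible shape for the "compute and log fairness metrics" placeholder, sketched in plain Python rather than with the metric objects the module already uses. The GroupedAUC class and binary_auc helper are hypothetical names, not part of meds-torch. The idea is to accumulate scores and labels per group across batches, then compute a pairwise (Mann-Whitney) AUC for each group at the end of the epoch.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence


def binary_auc(scores: Sequence[float], labels: Sequence[int]) -> Optional[float]:
    """Pairwise AUC: fraction of (positive, negative) pairs ranked correctly.

    Ties count as half. Returns None when a class is missing, since AUC is
    undefined in that case. Quadratic in group size; fine for a sketch.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        return None
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


class GroupedAUC:
    """Accumulates predictions per group and reports one AUC per group."""

    def __init__(self) -> None:
        self._scores: Dict[Hashable, List[float]] = defaultdict(list)
        self._labels: Dict[Hashable, List[int]] = defaultdict(list)

    def update(self, scores, labels, groups) -> None:
        # Called once per batch; routes each sample to its group's buffers.
        for s, y, g in zip(scores, labels, groups):
            self._scores[g].append(float(s))
            self._labels[g].append(int(y))

    def compute(self) -> Dict[Hashable, Optional[float]]:
        return {g: binary_auc(self._scores[g], self._labels[g]) for g in sorted(self._scores)}


# Two toy "batches", invented for this example only.
metric = GroupedAUC()
metric.update([0.9, 0.2, 0.6, 0.4], [1, 0, 1, 0], ["F", "F", "M", "M"])
metric.update([0.3, 0.8], [1, 0], ["M", "F"])
group_auc = metric.compute()
print(group_auc)  # {'F': 1.0, 'M': 0.5}
```

Accumulating over the whole epoch matters here because AUC is not an average of per-batch AUCs, and small groups may have no positives or no negatives within a single batch.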

Oufattole commented 5 days ago

Fixing issue #9 would be helpful here