Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with ⚡ PyTorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

Callback for writing predictions #205

Closed · yuvalkirstain closed this issue 2 years ago

yuvalkirstain commented 2 years ago

🚀 Feature

Add a callback to log the predictions of the model in the evaluation stage.

Motivation

When developing a model, it can be helpful to look at its predictions. It is pretty wasteful to run predictions during evaluation (generation for thousands of examples with a very large model) and then run the exact same model once again with a different script just to obtain those predictions. I suggest logging the predictions after each evaluation loop.

Pitch

Sometimes aggregated metrics are not enough and we need to see with our own eyes what our models predict. Can you please add a tool that will allow us to do that during training?

mathemusician commented 2 years ago

@yuvalkirstain How do you imagine this feature being used?

justusschock commented 2 years ago

@yuvalkirstain Would something like this work for you? It is hard to generalize, because outputs can be of varying types, but by extending this existing callback you should be up and running pretty quickly :)

yuvalkirstain commented 2 years ago

@mathemusician Thanks for the help! :) How I imagine the feature being used: I'd like to pass an argument that controls which predictions are saved, similar to the way we control which checkpoints are saved. For example, I might want to save only the predictions from the validation epoch in which the model performed best, or the predictions from every validation epoch. Then I can access those predictions and know their corresponding step/epoch (similar to the way we can access the saved checkpoints).

If we consider only text-to-text models (like GPT and T5), it should be easy to aggregate their inputs and predictions during validation and then save them. For translation, for example, the compute_generate_metrics function is a good place to collect the inputs and predictions, and once the validation epoch ends they can be saved automatically.
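
A rough sketch of what I have in mind, using only plain PyTorch Lightning hooks; self.model, self.tokenizer, and the file layout are placeholders for whatever the concrete task module provides, not lightning-transformers API:

import json
from pathlib import Path

import pytorch_lightning as pl


class PredictionLoggingModule(pl.LightningModule):
    """Collects (input, prediction) pairs during validation and writes them once per epoch."""

    def __init__(self, model, tokenizer, pred_dir: str = "val_predictions"):
        super().__init__()
        self.model = model          # any text-to-text model with a .generate() method
        self.tokenizer = tokenizer  # matching 🤗 tokenizer
        self.pred_dir = Path(pred_dir)
        self.pred_dir.mkdir(parents=True, exist_ok=True)
        self._val_records = []

    def validation_step(self, batch, batch_idx):
        # Placeholder generation call; in lightning-transformers this would sit
        # next to compute_generate_metrics.
        generated = self.model.generate(batch["input_ids"])
        inputs = self.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
        preds = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        self._val_records.extend(
            {"input": src, "prediction": pred} for src, pred in zip(inputs, preds)
        )

    def on_validation_epoch_end(self):
        # One JSONL file per validation epoch, named after the epoch it came from.
        out_path = self.pred_dir / f"epoch_{self.current_epoch}.jsonl"
        with open(out_path, "w") as f:
            for record in self._val_records:
                f.write(json.dumps(record) + "\n")
        self._val_records.clear()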

@justusschock Thanks for the help! :) Is there an existing example that uses this class to log predictions that are generated during validation? I am not sure how to extend this base class for my purposes.

justusschock commented 2 years ago

@yuvalkirstain I am not sure if we have an official example other than the one given in the documentation I linked. I am afraid that it currently only supports predictions generated with trainer.predict, because the indices of the current batch are not available during validation.

However, if you have your own logic for computing them, you could use something like this:


from typing import Any, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.utilities.types import STEP_OUTPUT


class ValidationPredictionWriter(BasePredictionWriter):
    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Called when the validation batch ends."""
        if not self.interval.on_batch:
            return
        # TODO: replace ... with your custom logic, or None if you don't need the
        # indices for saving.
        batch_indices = ...  # originally: trainer.predict_loop.epoch_loop.current_batch_indices
        self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

Such custom logic could, for example, be to include the batch indices in your batch.
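
A minimal sketch of that idea, assuming your dataset returns dicts (the class and key names are only illustrative):

from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Wraps an existing dataset so that every item also carries its sample index."""

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Assumes the wrapped dataset yields dicts; adapt for tuple-style datasets.
        return {**item, "index": idx}

With the default collate function, every batch then contains an "index" entry that the callback above can use in place of the ... placeholder.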

yuvalkirstain commented 2 years ago

@justusschock Thank you so much for this suggestion! I will probably try to add it to my script soon. If it works well, I will post it here and close the issue.

yuvalkirstain commented 2 years ago

We implemented it, and it was easier than we thought. Thanks so much for the help. I'm posting it here in case someone else wants to use it :)

import json
import os
from pathlib import Path

from pytorch_lightning.callbacks import BasePredictionWriter


class PredictionWriter(BasePredictionWriter):
    def __init__(self, dir_name: str, write_val_preds: bool):
        super().__init__()
        self.val_out_file = None
        self.write_val_preds = write_val_preds

        if self.write_val_preds:
            self.val_output_dir = Path(dir_name) / "val"
            os.makedirs(self.val_output_dir, exist_ok=True)

    def on_validation_start(self, trainer, pl_module):
        if self.write_val_preds:
            self.val_out_file = open(self.val_output_dir / "val_preds.jsonl", "w")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if self.write_val_preds:
            _write_outputs(self.val_out_file, outputs)

    def on_validation_end(self, trainer, pl_module):
        if self.write_val_preds:
            self.val_out_file.close()


def _write_outputs(out_file, outputs):
    # Assumes each validation step returns a list of JSON-serializable dicts.
    for output in outputs:
        out_file.write(json.dumps(output) + "\n")
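
For reference, this is roughly how we register it; the directory name and flag come from our own config, and validation_step is expected to return a list of JSON-serializable dicts:

import pytorch_lightning as pl

writer = PredictionWriter(dir_name="outputs", write_val_preds=True)
trainer = pl.Trainer(callbacks=[writer])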