huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

callback to implement how the predictions should be stored. #32145

Open sachinya00 opened 2 months ago

sachinya00 commented 2 months ago

I am exploring the distributed inference capabilities of the Hugging Face Trainer in transformers. I need to run distributed inference across multiple devices or nodes and save the predictions to a file. However, after reviewing the available callbacks, I did not find any that support this task. Furthermore, trainer.predict returns only the predictions and labels (plus metrics), without the original input batches used for inference.
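
For comparison, here is a minimal sketch of the baseline that works without any callback: call trainer.predict, then write the results from the main process only, since predict already gathers outputs from all processes. It joins each prediction back to its inputs by index, which assumes a map-style dataset, that predict returns exactly one row per dataset row in dataset order, and that output.predictions is a single array rather than a tuple. The function name write_predictions_with_inputs and the input_key parameter are hypothetical.

```python
import json


def write_predictions_with_inputs(trainer, dataset, path, input_key="input_ids"):
    """Hypothetical helper: run predict() and write one JSON line per example.

    Assumes `dataset` is map-style (supports len() and indexing), that each row
    is a dict containing `input_key`, and that predictions come back in
    dataset order as a single array (not a tuple of arrays).
    """
    output = trainer.predict(dataset)  # PredictionOutput(predictions, label_ids, metrics)
    # predict() gathers results on every process, so write only once, from rank 0
    if not trainer.is_world_process_zero():
        return
    preds, labels = output.predictions, output.label_ids

    def to_py(x):
        # NumPy arrays and tensors expose .tolist(); plain Python values pass through
        return x.tolist() if hasattr(x, "tolist") else x

    with open(path, "w", encoding="utf-8") as f:
        for i in range(len(dataset)):
            record = {
                input_key: to_py(dataset[i][input_key]),
                "prediction": to_py(preds[i]),
                "label": None if labels is None else to_py(labels[i]),
            }
            f.write(json.dumps(record) + "\n")
```

This keeps everything in memory until predict finishes, which is exactly what a per-batch writer would avoid for very large prediction sets.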

PyTorch Lightning offers a flexible mechanism for handling prediction outputs using custom callbacks. For example, the following PyTorch Lightning code snippet demonstrates how a custom BasePredictionWriter callback can be implemented to save predictions to files:

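```python
import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # One subdirectory per dataloader; create it before saving
        batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(batch_dir, exist_ok=True)
        # Include the rank so that processes do not overwrite each other's files
        torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}_rank{trainer.global_rank}.pt"))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        os.makedirs(self.output_dir, exist_ok=True)
        # Each process saves its own shard of the predictions
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))


pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
```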
amyeroberts commented 2 months ago

Hi @sachinya00, thanks for opening a feature request!

Anyone can write their own callback by inheriting from TrainerCallback and passing it to the Trainer class.

If you or anyone else would like to open a PR to add a specific functionality we'd be happy to review!

cc @muellerzr @SunMarc

sachinya00 commented 2 months ago

TrainerCallback is not an ideal solution when we only want to run inference (no training involved): its hooks are built around the training loop, and the prediction-related ones (on_prediction_step, on_predict) do not receive the model outputs or the input batches. Currently, I am overriding the prediction_step method in a custom Trainer subclass to write predictions to a file:

```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def __init__(self, *args, prediction_writer=None, **kwargs):
        """
        Initializes the CustomTrainer.

        Args:
            prediction_writer (PredictionWriter): Instance of PredictionWriter for saving predictions.
        """
        super().__init__(*args, **kwargs)
        self.prediction_writer = prediction_writer

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Overrides the prediction step to include saving predictions.

        Args:
            model (PreTrainedModel): The model to use for predictions.
            inputs (dict): The inputs to the model.
            prediction_loss_only (bool): Whether to return only the loss.
            ignore_keys (List[str]): Keys to ignore for the model output.

        Returns:
            tuple: (loss, logits, labels)
        """
        # Call the original prediction step
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )

        # Save predictions if a writer is provided and logits were computed
        # (logits is None when prediction_loss_only=True)
        if self.prediction_writer is not None and logits is not None:
            self.prediction_writer.save_predictions(
                pred_batch=logits, labels=labels, input_ids=inputs.get("input_ids")
            )

        return loss, logits, labels
```
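
The PredictionWriter used above is not shown in the issue and is not part of transformers. A minimal hypothetical sketch could append one JSON line per example to a per-rank file, so that processes never write to the same file; passing rank=training_args.process_index when constructing it is an assumption about how you would wire it up. It also assumes logits is a single tensor or array (not a tuple of outputs). Because prediction_step runs on each process before any gathering, each rank only sees its own shard, and with default even-batch sharding the final batches may contain a few duplicated samples, so the per-rank files may need deduplication when merged.

```python
import json
import os


class PredictionWriter:
    """Hypothetical writer: appends one JSON line per example to a per-rank file."""

    def __init__(self, output_dir, rank=0):
        os.makedirs(output_dir, exist_ok=True)
        # One file per process, so ranks never write to the same file
        self.path = os.path.join(output_dir, f"predictions_rank{rank}.jsonl")

    @staticmethod
    def _to_py(x):
        # torch tensors: detach and move to CPU before converting
        if hasattr(x, "detach"):
            x = x.detach().cpu()
        # tensors and NumPy arrays expose .tolist(); plain Python values pass through
        return x.tolist() if hasattr(x, "tolist") else x

    def save_predictions(self, pred_batch, labels=None, input_ids=None):
        if pred_batch is None:  # nothing to write (e.g. prediction_loss_only=True)
            return
        preds = self._to_py(pred_batch)
        labels = None if labels is None else self._to_py(labels)
        input_ids = None if input_ids is None else self._to_py(input_ids)
        with open(self.path, "a", encoding="utf-8") as f:
            for i, pred in enumerate(preds):
                record = {
                    "prediction": pred,
                    "label": None if labels is None else labels[i],
                    "input_ids": None if input_ids is None else input_ids[i],
                }
                f.write(json.dumps(record) + "\n")
```

Appending per batch keeps memory flat during long inference runs; the trade-off is that ordering and deduplication across ranks are left to a merge step afterwards.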