sachinya00 opened this issue 2 months ago
Hi @sachinya00, thanks for opening a feature request!
Anyone can write their own callbacks by inheriting from TrainerCallback and passing them to the Trainer class.
If you or anyone else would like to open a PR to add a specific functionality we'd be happy to review!
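For reference, a minimal custom callback might look roughly like this (a sketch only; the model, training arguments, and dataset are assumed to already exist):

```python
from transformers import Trainer, TrainerCallback

class StepLoggingCallback(TrainerCallback):
    # Runs at the end of each optimization step
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"Reached step {state.global_step}")

# model, training_args and train_dataset are assumed to be defined elsewhere
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[StepLoggingCallback()],
)
```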
cc @muellerzr @SunMarc
TrainerCallback is not the ideal solution when we only want to perform inference (no training involved), as it only allows customizing behavior at the various training stages.

Currently I am overriding the prediction_step method in a custom Trainer class to write predictions to a file:
```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def __init__(self, *args, prediction_writer=None, **kwargs):
        """
        Initializes the CustomTrainer.

        Args:
            prediction_writer (PredictionWriter): Instance of PredictionWriter for saving predictions.
        """
        super().__init__(*args, **kwargs)
        self.prediction_writer = prediction_writer

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Overrides the prediction step to include saving predictions.

        Args:
            model (PreTrainedModel): The model to use for predictions.
            inputs (dict): The inputs to the model.
            prediction_loss_only (bool): Whether to return only the loss.
            ignore_keys (List[str]): Keys to ignore for the model output.

        Returns:
            tuple: (loss, logits, labels)
        """
        # Call the original prediction step
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )

        # Save predictions if a prediction writer is provided
        if self.prediction_writer:
            self.prediction_writer.save_predictions(
                pred_batch=logits, labels=labels, input_ids=inputs["input_ids"]
            )

        return loss, logits, labels
```
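The PredictionWriter passed in above is a separate helper class (not shown here); a minimal sketch of what it might look like, assuming a JSONL output file and that labels are present:

```python
import json


class PredictionWriter:
    """Hypothetical helper that appends each prediction batch to a JSONL file."""

    def __init__(self, output_path):
        self.output_path = output_path

    def save_predictions(self, pred_batch, labels, input_ids):
        # Append one JSON line per example in the batch
        with open(self.output_path, "a") as f:
            for ids, preds, label in zip(
                input_ids.tolist(), pred_batch.tolist(), labels.tolist()
            ):
                f.write(json.dumps({"input_ids": ids, "prediction": preds, "label": label}) + "\n")
```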
I am exploring distributed inference with the Hugging Face Trainer for transformers: I need to run inference across multiple devices or nodes and save the predictions to a file. However, after reviewing the available callbacks, I did not find any that facilitate this. Furthermore, the trainer.predict method returns only the labels and predictions, without including the original input batches used for inference.
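To illustrate the second point, trainer.predict returns a PredictionOutput that exposes only predictions, labels, and metrics (test_dataset below is assumed to exist):

```python
output = trainer.predict(test_dataset)
print(output.predictions)  # model outputs, e.g. logits
print(output.label_ids)    # labels from the dataset, if present
print(output.metrics)      # e.g. test_loss, test_runtime
# The original input batches (e.g. input_ids) are not part of this output.
```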
PyTorch Lightning offers a flexible mechanism for handling prediction outputs using custom callbacks. For example, the following snippet demonstrates how a custom BasePredictionWriter callback can be implemented to save predictions to files:
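Roughly along these lines (a minimal sketch based on the BasePredictionWriter API; the output directory and write interval are illustrative):

```python
import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval="epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # Each rank writes its own file, so this also works for distributed inference
        torch.save(
            predictions,
            os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"),
        )

# trainer = pl.Trainer(callbacks=[CustomWriter("./predictions")], ...)
# trainer.predict(model, dataloaders=..., return_predictions=False)
```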