huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
133k stars, 26.54k forks

Data Map Trainer Callback #31647

Open nbertagnolli opened 3 months ago

nbertagnolli commented 3 months ago

Feature request

It would be nice to have a callback for the Trainer class that creates Data Maps. See the paper for more details: https://arxiv.org/pdf/2009.10795. A Data Map measures how a model's predictions on specific training examples change over the course of training.
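
For reference, the paper summarizes each training example with three statistics computed across epochs: confidence (the mean probability assigned to the gold label), variability (the standard deviation of that probability), and correctness (the fraction of epochs in which the example is predicted correctly). Below is a minimal sketch of those statistics, assuming the per-epoch gold-label probabilities and correctness flags have already been collected. The function name `data_map_stats` and the input layout are illustrative assumptions, not part of the proposed callback.

```python
import statistics
from typing import Dict, List


def data_map_stats(
    gold_probs: List[List[float]],
    correct: List[List[bool]],
) -> List[Dict[str, float]]:
    """Summarize training dynamics per example.

    gold_probs[e][i] is the probability the model assigned to the gold label
    of example i after epoch e; correct[e][i] says whether the prediction for
    example i was right after epoch e. (Hypothetical input layout.)
    """
    n_examples = len(gold_probs[0])
    stats = []
    for i in range(n_examples):
        probs = [epoch[i] for epoch in gold_probs]
        hits = [epoch[i] for epoch in correct]
        stats.append(
            {
                # Mean gold-label probability across epochs.
                "confidence": statistics.fmean(probs),
                # Population standard deviation across epochs.
                "variability": statistics.pstdev(probs),
                # Fraction of epochs with a correct prediction.
                "correctness": sum(hits) / len(hits),
            }
        )
    return stats


# Two examples tracked over three epochs.
gold_probs = [[0.9, 0.2], [0.9, 0.6], [0.9, 0.4]]
correct = [[True, False], [True, True], [True, False]]
stats = data_map_stats(gold_probs, correct)
```

In this toy input the first example is stable and confidently correct (high confidence, zero variability), while the second fluctuates between epochs, which is the kind of example a Data Map would flag as ambiguous.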

The Callback should support:

Running this colab notebook I made will generate data map outputs for classification tasks using the Trainer, in line with what I was thinking. Here is what I have so far; it works well for multilabel and multiclass classification using transformers.

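import inspect
import json
import os
import warnings
from typing import Callable, List, Optional

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from transformers import PreTrainedModel, TrainerCallback, TrainingArguments


class DataMapCallback(TrainerCallback):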
    """Trainer Callback to save DataMap data.

    Original Paper: https://arxiv.org/pdf/2009.10795.pdf.

    This callback saves the model's predictions on every training example to
    callback_dir/{log_count}.json each time the event selected by `log_on`
    fires ("epoch", "save", "evaluate", or "step").
    """

    def __init__(
        self,
        log_on: str = "epoch",
        callback_dir: str = ".",
        n_log_steps: Optional[int] = None,
        prediction_fn: Optional[Callable[[PreTrainedModel, DataLoader, TrainingArguments], List[List[float]]]] = None,
    ):
        self.callback_dir = callback_dir
        self.log_on = log_on
        self.log_count = 0
        self.n_log_steps = n_log_steps
        self.prediction_fn = self._predict if prediction_fn is None else prediction_fn

        # Handle discrepancies in how we initialize the logging mode.
        if n_log_steps is not None and self.log_on != "step":
            raise ValueError(
                "n_log_steps is only valid when log_on='step'. If you want to run data maps based on steps, please specify log_on='step'."
            )
        elif n_log_steps is None and self.log_on == "step":
            warnings.warn(
                "n_log_steps was not specified; defaulting to 1, which saves a data map after every step and can produce a large number of files."
            )
            self.n_log_steps = 1

        # Create the directory if it doesn't exist.
        os.makedirs(self.callback_dir, exist_ok=True)

    def _predict(self, model, train_data_loader, args):
        if train_data_loader.batch_size is None:
            batch_size = args.per_device_train_batch_size
        else:
            batch_size = train_data_loader.batch_size
        batches = BatchSampler(
            SequentialSampler(train_data_loader.dataset),
            batch_size,
            False,
        )

        with torch.no_grad():
            predictions = []
            for batch in batches:
                # Add 1 to the last index because Python slices are half-open: [start, end)
                start_idx, end_idx = batch[0], batch[-1] + 1

                # Extract the sample from the training dataset
                sample = train_data_loader.dataset[start_idx:end_idx]

                # Make sure to apply any data collators
                sample = train_data_loader.collate_fn(sample)

                # Move all samples to the appropriate device. We only do this
                # for keys that are both arguments of model.forward and present in the sample.
                forward_args = set(inspect.getfullargspec(model.forward).args).intersection(
                    set(sample.keys())
                )
                sample = {k: torch.as_tensor(sample[k]).to(model.device) for k in forward_args}

                # Perform inference using the model
                current_preds = model(**sample)

                # Convert the predictions to a list and append them to the result
                predictions += current_preds.logits.tolist()

        return predictions

    def _save_predictions(self, model, train_data_loader, args):
        predictions = self.prediction_fn(model, train_data_loader, args)

        # Save Predictions.
        with open(os.path.join(self.callback_dir, f"{self.log_count}.json"), "w") as f:
            json.dump(predictions, f)
        self.log_count += 1

    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples at the end of an epoch."""

        if self.log_on == "epoch":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_save(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples when a checkpoint is saved."""
        if self.log_on == "save":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples when we run an evaluation loop."""
        if self.log_on == "evaluate":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_step_end(self, args, state, control, logs=None, **kwargs):
        """Predict at the end of every n_log_steps-th step."""
        if self.log_on == "step" and state.global_step % self.n_log_steps == 0:
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)
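
The callback above writes one JSON file of raw logits per logging round. As a rough sketch of how those files might be read back, the snippet below loads them in order and converts each row of logits into the probability of the gold label. It assumes single-label multiclass classification (multilabel outputs would need a per-class sigmoid instead of a softmax), that `callback_dir` contains only the callback's `{log_count}.json` files, and that the gold labels are available as a list in dataset order. The helper names `softmax` and `load_gold_probs` are hypothetical.

```python
import json
import math
import os
import tempfile
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def load_gold_probs(callback_dir: str, labels: List[int]) -> List[List[float]]:
    """Return gold_probs[k][i]: probability of example i's gold label in k.json."""
    # Sort numerically so that 10.json comes after 9.json.
    names = sorted(
        (n for n in os.listdir(callback_dir) if n.endswith(".json")),
        key=lambda n: int(os.path.splitext(n)[0]),
    )
    gold_probs = []
    for name in names:
        with open(os.path.join(callback_dir, name)) as f:
            logits = json.load(f)
        gold_probs.append([softmax(row)[y] for row, y in zip(logits, labels)])
    return gold_probs


# Simulate two logging rounds for two examples with three classes.
fake_rounds = [
    [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
    [[1.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
]
with tempfile.TemporaryDirectory() as tmp:
    for k, logits in enumerate(fake_rounds):
        with open(os.path.join(tmp, f"{k}.json"), "w") as f:
            json.dump(logits, f)
    gold_probs = load_gold_probs(tmp, labels=[0, 2])
```

The resulting `gold_probs` is indexed by logging round and then by example, so per-example means and standard deviations across rounds can be computed directly from it.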

Motivation

Optimizing data is as important as correctly configuring your model. Data Maps are a powerful tool for understanding the data we use to train on a specific task. Monitoring the loss in TensorBoard during training can identify many bugs, and Data Maps can be equally valuable. In my day-to-day work, this technique has seriously increased the performance of production models I've trained at multiple companies. I think it's really useful for gaining insights about your data and for pushing the limits of your performance. I want to see everyone get the same benefits I've seen. I wrote a blog post on doing this with sklearn if you want to see a simple example.

Your contribution

I'd love to contribute this. I have already created a working prototype in this colab notebook. It generates data map outputs for classification tasks using the Trainer. I'm working on an example for non-classification tasks as well. If you'd be willing to guide me on this addition, and you think it's valuable, I'd do as much of this as possible : ).

amyeroberts commented 3 months ago

cc @muellerzr @SunMarc

muellerzr commented 3 months ago

@nbertagnolli nice idea! Feel free to open a draft PR if you'd like to add it to the callbacks :)

nbertagnolli commented 2 months ago

I'll get started on that! Thanks for the encouragement : )