UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

How to customize the dataloader? e.g. Custom Data Augmentation #3015

Open msciancalepore98 opened 1 week ago

msciancalepore98 commented 1 week ago

Hi,

I'm used to the old .fit behaviour, where I could pass in my own DataLoader and implement the Dataset myself according to my needs.

With the new trainer interface, how am I supposed to tweak the dataloader?

Let's say I want to apply some random transformations to the input text: how can I do that right now? Of course, changing the original dataset by augmenting it statically is a no-go.

Thanks!

thearod5 commented 1 week ago

I have the exact same question; I really liked having access to the DataLoader and its accessories (batch sampling). I have a simple use case with a class imbalance, so I would like to evaluate undersampling / oversampling, but to my knowledge there is no way to do that. Let me know if you succeed or find clues on how to modify the data loader.
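
For the imbalance case specifically, one possible approach is PyTorch's built-in WeightedRandomSampler, which oversamples rarer classes; here is a minimal sketch (the labels below are made up, and you still need a way to inject the sampler into the dataloader, which the next comment shows):

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical per-example class labels; replace with your own column.
labels = [0, 0, 0, 0, 1]  # heavy imbalance: four examples of class 0, one of class 1

# Weight each example inversely to its class frequency so rarer classes
# are drawn more often (oversampling with replacement).
class_counts = torch.bincount(torch.tensor(labels))
weights = [1.0 / class_counts[label].item() for label in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
# This sampler can then replace the default one when rebuilding the DataLoader
# inside a subclassed get_train_dataloader (see the next comment).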

thearod5 commented 1 week ago

I solved it! Subclass their trainer and re-implement the get_train_dataloader method to inject your custom implementation!

from sentence_transformers import SentenceTransformerTrainer
from torch.utils.data import DataLoader

from trace_model_trainer.models.st.balanced_data_loader import BalancedSampler

class CustomTrainer(SentenceTransformerTrainer):
    def get_train_dataloader(self) -> DataLoader:
        data_loader = super().get_train_dataloader()  # build the default dataloader so everything except the sampler can be reused
        return DataLoader(
            self.train_dataset,
            batch_sampler=BalancedSampler(self.train_dataset),  # custom batch sampler (e.g. for class balancing)
            collate_fn=data_loader.collate_fn,
            num_workers=data_loader.num_workers,
            pin_memory=data_loader.pin_memory,
            persistent_workers=data_loader.persistent_workers,
            prefetch_factor=data_loader.prefetch_factor,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            pin_memory_device=data_loader.pin_memory_device
            # batch_size, shuffle, sampler, and drop_last are determined by the batch_sampler
        )
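
The BalancedSampler above is project-specific; as a rough sketch, a batch sampler only has to yield lists of dataset indices, so a balanced version could look something like this (the "label" column and 0/1 classes are assumptions, adapt to your own schema):

import random
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Sketch of a batch sampler that mixes majority- and minority-class
    indices in every batch. Assumes the dataset has a "label" column with
    classes 0 (majority) and 1 (minority)."""

    def __init__(self, dataset, batch_size: int = 32):
        self.batch_size = batch_size
        self.minority = [i for i, label in enumerate(dataset["label"]) if label == 1]
        self.majority = [i for i, label in enumerate(dataset["label"]) if label == 0]

    def __iter__(self):
        half = self.batch_size // 2
        majority = self.majority[:]
        random.shuffle(majority)
        for start in range(0, len(majority), half):
            # Pair each chunk of majority indices with oversampled minority indices.
            batch = majority[start : start + half] + random.choices(self.minority, k=half)
            random.shuffle(batch)
            yield batch

    def __len__(self):
        half = self.batch_size // 2
        return (len(self.majority) + half - 1) // half
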
tomaarsen commented 2 days ago

Hello!

There are two ways to do this. @thearod5 already found the first one, well done! The second one consists of adding a transform to the training (and/or evaluation) dataset: a function that gets called on a batch whenever that batch is requested from the Dataset. See the set_transform docs here.
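
To illustrate the mechanism in isolation (the "text" column and the uppercase transform below are just stand-ins):

from datasets import Dataset

# Minimal illustration: the transform runs lazily on every batch access,
# so the stored data stays untouched and each access can differ.
dataset = Dataset.from_dict({"text": ["hello world", "foo bar"]})
dataset.set_transform(lambda batch: {"text": [t.upper() for t in batch["text"]]})
print(dataset[:2])  # {'text': ['HELLO WORLD', 'FOO BAR']}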

Here is a complete example that uses it:

# pip install nltk textblob==0.15.3 scipy==1.10.1 textaugment

import random
import logging
from typing import Dict, List, Literal
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator

import nltk
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

from textaugment import Wordnet

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ triplets",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train")
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]

# Augment with WordNet synonym replacement (or translate, or whatever you want to do!)
augmenter = Wordnet()
def augment_text(batch: Dict[Literal["question", "answer"], List[str]]) -> Dict[str, List[str]]:
    questions = batch["question"]
    answers = batch["answer"]

    questions = [augmenter.augment(question) if random.random() < 0.5 else question for question in questions]
    answers = [augmenter.augment(answer) if random.random() < 0.5 else answer for answer in answers]

    print("augment_text called!")
    return {
        "question": questions,
        "answer": answers,
    }

train_dataset.set_transform(augment_text)

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-gooaq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=250,
    logging_first_step=True,
    run_name="mpnet-base-gooaq",  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
# corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries} |
    {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
# Uncomment this line to evaluate the base model; will be included in model card
# dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained("models/mpnet-base-gooaq/final")

# 9. (Optional) Push it to the Hugging Face Hub
# model.push_to_hub("mpnet-base-gooaq")

(Note: apparently textaugment is quite outdated, so I added the working dependency versions at the top.)

This allows you to augment your data on the fly. Every epoch will be different as a result, especially if you apply your augmentation(s) with some probability.

I hope this helps your use case.

thearod5 commented 2 days ago

This is a cool approach, thank you!!