Open msciancalepore98 opened 1 week ago

Hi,

I've always been used to the old `.fit` behaviour, where I could pass in my own `DataLoader` and implement the `Dataset` myself, according to my needs. With the new trainer interface, how am I supposed to tweak the dataloader?

Let's say I want to apply some random transformations to the input text: how can I do that right now? Of course, changing the original dataset by augmenting it statically is a no-go.

Thanks!

I have the same exact question. I really like having access to the data loader and its accessories (batch sampling). I have a simple use case with class imbalance, so I would like to evaluate undersampling / oversampling, but to my knowledge there is no way to do that. Let me know if you succeed or find clues on how to modify the data loader.
I solved it! Subclass their trainer and re-implement the `get_train_dataloader` method to inject your custom implementation:
```python
from sentence_transformers import SentenceTransformerTrainer
from torch.utils.data import DataLoader

from trace_model_trainer.models.st.balanced_data_loader import BalancedSampler


class CustomTrainer(SentenceTransformerTrainer):
    def get_train_dataloader(self) -> DataLoader:
        # Build the default dataloader first, so that we only override the sampler
        data_loader = super().get_train_dataloader()
        return DataLoader(
            self.train_dataset,
            batch_sampler=BalancedSampler(self.train_dataset),
            collate_fn=data_loader.collate_fn,
            num_workers=data_loader.num_workers,
            pin_memory=data_loader.pin_memory,
            persistent_workers=data_loader.persistent_workers,
            prefetch_factor=data_loader.prefetch_factor,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            pin_memory_device=data_loader.pin_memory_device,
            # batch_size, shuffle, sampler, drop_last are all controlled by the batch_sampler
        )
```
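(Note that `BalancedSampler` is from my own project, so you can't import it directly. For the class-imbalance use case above, a rough sketch of what such a batch sampler could look like; the `label` column and the oversampling-with-replacement strategy are assumptions you'd adapt to your own data:)

```python
import random
from collections import defaultdict

from torch.utils.data import Sampler


class BalancedBatchSampler(Sampler):
    """Sketch of a batch sampler that oversamples minority classes so that
    every batch contains roughly the same number of samples per label."""

    def __init__(self, dataset, batch_size: int = 32, seed: int = 12):
        # Group dataset row indices by their label
        self.indices_per_label = defaultdict(list)
        for idx, label in enumerate(dataset["label"]):
            self.indices_per_label[label].append(idx)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __len__(self):
        # One "epoch" covers the largest class roughly once
        per_label = max(1, self.batch_size // len(self.indices_per_label))
        return max(len(idxs) for idxs in self.indices_per_label.values()) // per_label

    def __iter__(self):
        per_label = max(1, self.batch_size // len(self.indices_per_label))
        for _ in range(len(self)):
            batch = []
            for indices in self.indices_per_label.values():
                # Sampling with replacement oversamples the smaller classes
                batch.extend(self.rng.choices(indices, k=per_label))
            self.rng.shuffle(batch)
            yield batch
```

You can then train with `CustomTrainer(...)` exactly as you would with `SentenceTransformerTrainer`.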
Hello!

There are two ways to do this. @thearod5 already found the first one, well done! The second consists of adding a transform to the training (and/or evaluation) dataset: a function that gets called on a batch whenever that batch is requested from the `Dataset`. See the `Dataset.set_transform` documentation in the `datasets` library.
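In a nutshell, the transform is just a function mapping a batch dictionary to a batch dictionary, applied lazily every time rows are accessed. A minimal toy sketch:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello world", "foo bar"]})
# Called on a (mini-)batch of rows each time those rows are accessed
ds.set_transform(lambda batch: {"text": [text.upper() for text in batch["text"]]})
print(ds[0])  # => {'text': 'HELLO WORLD'}
```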
Here is a complete example that uses it:
```python
# pip install nltk textblob==0.15.3 scipy==1.10.1 textaugment
import logging
import random
from typing import Dict, List, Literal

import nltk
from datasets import Dataset, load_dataset
from textaugment import Wordnet

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

nltk.download("wordnet")
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ triplets",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train")
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]

# Use WordNet-based synonym replacement; or translate, or whatever you want to do!
augmenter = Wordnet()

def augment_text(batch: Dict[Literal["question", "answer"], List[str]]) -> Dict[str, List[str]]:
    questions = batch["question"]
    answers = batch["answer"]
    # Augment each text with 50% probability
    questions = [augmenter.augment(question) if random.random() < 0.5 else question for question in questions]
    answers = [augmenter.augment(answer) if random.random() < 0.5 else answer for answer in answers]
    print("augment_text called!")
    return {
        "question": questions,
        "answer": answers,
    }

train_dataset.set_transform(augment_text)

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-gooaq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=250,
    logging_first_step=True,
    run_name="mpnet-base-gooaq",  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# Queries: only the evaluation questions. Corpus: the evaluation answers plus
# 20k sampled answers (uncomment below to use the full corpus instead):
# corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    | {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
# Uncomment this line to evaluate the base model; the results will be included in the model card
# dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained("models/mpnet-base-gooaq/final")

# 9. (Optional) Push it to the Hugging Face Hub
# model.push_to_hub("mpnet-base-gooaq")
```
(Note: apparently `textaugment` is quite outdated, so I added the working pinned dependencies at the top.)
This allows you to augment your data on the fly. Every epoch will be different as a result, especially if you apply your augmentation(s) with some probability.
I hope this helps your use case.
This is a cool approach, thank you!!