Flexible Training and Retrieval for Late Interaction Models
PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.
You can install PyLate using pip:
pip install pylate
For evaluation dependencies, use:
pip install "pylate[eval]"
The complete documentation is available here, which includes in-depth guides, examples, and API references.
Here’s a simple example of training a ColBERT model on the MS MARCO triplet dataset using PyLate. This script demonstrates training with a contrastive loss and evaluating the model on a held-out evaluation set:
import torch
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from pylate import evaluation, losses, models, utils
# Define model parameters for contrastive training
model_name = "bert-base-uncased" # Choose the pre-trained model you want to use as base
batch_size = 32 # Larger batch size often improves results, but requires more memory
num_train_epochs = 1 # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/{run_name}"
# Define the ColBERT model. If the base model is not already a ColBERT model, a linear projection layer is added on top of the encoder.
model = models.ColBERT(model_name_or_path=model_name)
# Compiling the model with torch.compile can speed up training
model = torch.compile(model)
# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
# Split the dataset (this dataset does not have a validation set, so we split the training set)
splits = dataset.train_test_split(test_size=0.01)
train_dataset = splits["train"]
eval_dataset = splits["test"]
# Define the loss function
train_loss = losses.Contrastive(model=model)
# Initialize the evaluator
dev_evaluator = evaluation.ColBERTTripletEvaluator(
anchors=eval_dataset["query"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
)
# Configure the training arguments (e.g., epochs, batch sizes, mixed precision, learning rate)
args = SentenceTransformerTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
run_name=run_name, # Will be used in W&B if `wandb` is installed
learning_rate=3e-6,
)
# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
data_collator=utils.ColBERTCollator(model.tokenize),
)
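# Start the training process
trainer.train()
# Save the final model so it can be reloaded from the output directory
model.save_pretrained(output_dir)
After training, the model can be loaded from the output directory path:
from pylate import models
model = models.ColBERT(model_name_or_path="output/contrastive-bert-base-uncased")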
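The contrastive loss used in the first script can be pictured as a classification problem: each query must pick out its own positive among all the documents in the batch. The sketch below is a conceptual NumPy illustration, not PyLate's implementation; it assumes in-batch negatives (every other positive and negative in the batch counts as a negative for a given query), which is also one reason larger batches tend to help. The helper names maxsim and contrastive_loss are hypothetical.

```python
import numpy as np


def maxsim(query: np.ndarray, document: np.ndarray) -> float:
    """Late-interaction score: for each query token, take its best-matching
    document token (dot product), then sum over query tokens."""
    return float((query @ document.T).max(axis=1).sum())


def contrastive_loss(queries, positives, negatives) -> float:
    """Cross-entropy over in-batch candidates (assumed formulation).

    Query i is scored against every positive and every negative in the batch;
    its own positive (candidate index i) is the target class.
    """
    candidates = list(positives) + list(negatives)
    # scores[i, j] = MaxSim(query i, candidate j)
    scores = np.array([[maxsim(q, d) for d in candidates] for q in queries])
    # Numerically stable log-softmax over the candidates of each query
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    targets = np.arange(len(queries))
    # Average negative log-likelihood of each query's own positive
    return float(-log_probs[targets, targets].mean())
```

Because the candidate set grows with the batch, doubling batch_size doubles the number of negatives each query is compared against.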
For the best performance when training a ColBERT model, use knowledge distillation: train the model on the scores produced by a strong teacher model. Here's a simple example of how to train a model using knowledge distillation in PyLate on MS MARCO:
import torch
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from pylate import losses, models, utils
# Load the datasets required for knowledge distillation (train, queries, documents)
train = load_dataset(
path="lightonai/ms-marco-en-bge",
name="train",
)
queries = load_dataset(
path="lightonai/ms-marco-en-bge",
name="queries",
)
documents = load_dataset(
path="lightonai/ms-marco-en-bge",
name="documents",
)
# Set the transformation to load the documents/queries texts using the corresponding ids on the fly
train.set_transform(
utils.KDProcessing(queries=queries, documents=documents).transform,
)
# Define the base model, training parameters, and output directory
model_name = "bert-base-uncased" # Choose the pre-trained model you want to use as base
batch_size = 16
num_train_epochs = 1
# Set the run name for logging and output directory
run_name = "knowledge-distillation-bert-base"
output_dir = f"output/{run_name}"
# Initialize the ColBERT model from the base model
model = models.ColBERT(model_name_or_path=model_name)
# Compiling the model with torch.compile can speed up training
model = torch.compile(model)
# Configure the training arguments (e.g., epochs, batch size, learning rate)
args = SentenceTransformerTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
run_name=run_name,
learning_rate=1e-5,
)
# Use the Distillation loss function for training
train_loss = losses.Distillation(model=model)
# Initialize the trainer
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train,
loss=train_loss,
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)
# Start the training process
trainer.train()
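Knowledge distillation replaces the hard positive/negative labels with the teacher's scores for each query's candidate documents. A common way to do this, and the one sketched here under that assumption, is to turn both teacher and student scores into softmax distributions over the candidates and minimize the KL divergence between them. The functions below are hypothetical NumPy helpers; consult the source of PyLate's losses.Distillation for its exact normalization.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def distillation_loss(student_scores, teacher_scores) -> float:
    """Mean KL(teacher || student) over queries.

    Both inputs have shape (num_queries, num_candidates): row i holds the
    scores of query i's candidate documents, in the same order.
    """
    log_p_teacher = log_softmax(np.asarray(teacher_scores, dtype=float))
    log_p_student = log_softmax(np.asarray(student_scores, dtype=float))
    p_teacher = np.exp(log_p_teacher)
    kl_per_query = (p_teacher * (log_p_teacher - log_p_student)).sum(axis=-1)
    return float(kl_per_query.mean())
```

Because softmax is invariant to adding a constant, this loss depends on the differences between a query's candidate scores, not on their absolute level.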
PyLate supports Hugging Face Datasets, enabling seamless triplet-based and knowledge-distillation-based training. For contrastive training, you can use any of the existing Sentence Transformers triplet datasets. Below is an example of creating a custom triplet dataset for training:
from datasets import Dataset
dataset = [
{
"query": "example query 1",
"positive": "example positive document 1",
"negative": "example negative document 1",
},
{
"query": "example query 2",
"positive": "example positive document 2",
"negative": "example negative document 2",
},
{
"query": "example query 3",
"positive": "example positive document 3",
"negative": "example negative document 3",
},
]
dataset = Dataset.from_list(mapping=dataset)
splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]
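Mined data often comes as one query, one positive and several hard negatives. Since the triplet format has exactly one negative per row, such data can be expanded into one row per negative before calling Dataset.from_list. The input keys query, positive and negatives used below are an assumed layout for illustration, and to_triplet_rows is a hypothetical helper.

```python
from typing import Dict, List


def to_triplet_rows(examples: List[dict]) -> List[Dict[str, str]]:
    """Expand {"query", "positive", "negatives"} records (assumed layout)
    into flat {"query", "positive", "negative"} triplet rows."""
    rows = []
    for example in examples:
        # One training row per mined negative; query and positive are repeated
        for negative in example["negatives"]:
            rows.append(
                {
                    "query": example["query"],
                    "positive": example["positive"],
                    "negative": negative,
                }
            )
    return rows
```

The resulting list has the same shape as the dataset list above, so it can be passed directly to Dataset.from_list(mapping=rows).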
To create a knowledge distillation dataset, you can use the following snippet:
from datasets import Dataset
dataset = [
{
"query_id": 54528,
"document_ids": [
6862419,
335116,
339186,
],
"scores": [
0.4546215673141326,
0.6575686537173476,
0.26825184192900203,
],
},
{
"query_id": 749480,
"document_ids": [
6862419,
335116,
339186,
],
"scores": [
0.2546215673141326,
0.7575686537173476,
0.96825184192900203,
],
},
]
dataset = Dataset.from_list(mapping=dataset)
documents = [
{"document_id": 6862419, "text": "example doc 1"},
{"document_id": 335116, "text": "example doc 2"},
{"document_id": 339186, "text": "example doc 3"},
]
queries = [
{"query_id": 54528, "text": "example query 1"},
{"query_id": 749480, "text": "example query 2"},
]
documents = Dataset.from_list(mapping=documents)
queries = Dataset.from_list(mapping=queries)
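The distillation dataset stores only ids; the texts live in the separate queries and documents datasets, and utils.KDProcessing resolves them on the fly in the training script. The plain-Python sketch below shows the idea of that lookup under our own assumptions (the helper names and the output keys query, documents and scores are hypothetical, not KDProcessing's actual output). It also shows why every query_id referenced in the training rows needs a matching entry in queries: a missing id has no text to resolve to.

```python
def build_lookup(rows, id_key):
    """Map each id to its text, e.g. {749480: "example query"}."""
    return {row[id_key]: row["text"] for row in rows}


def resolve_example(example, query_texts, document_texts):
    """Replace ids with texts; scores stay aligned with the document order."""
    return {
        "query": query_texts[example["query_id"]],  # KeyError if the query is missing
        "documents": [document_texts[doc_id] for doc_id in example["document_ids"]],
        "scores": list(example["scores"]),
    }


query_texts = build_lookup([{"query_id": 749480, "text": "example query"}], "query_id")
document_texts = build_lookup(
    [
        {"document_id": 6862419, "text": "example doc 1"},
        {"document_id": 335116, "text": "example doc 2"},
        {"document_id": 339186, "text": "example doc 3"},
    ],
    "document_id",
)
resolved = resolve_example(
    {
        "query_id": 749480,
        "document_ids": [6862419, 335116, 339186],
        "scores": [0.2546215673141326, 0.7575686537173476, 0.96825184192900203],
    },
    query_texts,
    document_texts,
)
```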
PyLate makes it easy to retrieve the top documents for a set of queries using a trained ColBERT model and a Voyager index. Simply load the model and initialize the index:
from pylate import indexes, models, retrieve
model = models.ColBERT(
model_name_or_path="lightonai/colbertv2.0",
)
index = indexes.Voyager(
index_folder="pylate-index",
index_name="index",
override=True,
)
retriever = retrieve.ColBERT(index=index)
Once the model and index are set up, we can add documents to the index using their embeddings and corresponding ids:
documents_ids = ["1", "2", "3"]
documents = [
"document 1 text", "document 2 text", "document 3 text"
]
# Encode the documents
documents_embeddings = model.encode(
documents,
batch_size=32,
is_query=False, # Encoding documents
show_progress_bar=True,
)
# Add the documents ids and embeddings to the Voyager index
index.add_documents(
documents_ids=documents_ids,
documents_embeddings=documents_embeddings,
)
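The is_query flag matters because ColBERT formats queries and documents differently before encoding them. In the original ColBERT design, a [Q] or [D] marker token follows [CLS], and queries (but not documents) are padded with [MASK] tokens up to a fixed length so that the model can use those extra embeddings for query augmentation. The token-level sketch below illustrates that formatting; prepare_tokens is a hypothetical helper, and PyLate's actual prefixes and lengths are configurable and may differ.

```python
def prepare_tokens(tokens: list, is_query: bool, query_length: int = 32) -> list:
    """Hypothetical ColBERT-style input formatting on token strings."""
    marker = "[Q]" if is_query else "[D]"
    formatted = ["[CLS]", marker, *tokens, "[SEP]"]
    if is_query:
        # Truncate long queries, then pad short ones with [MASK] to a fixed length
        formatted = formatted[:query_length]
        formatted += ["[MASK]"] * (query_length - len(formatted))
    return formatted
```

Documents keep their natural length, while every query ends up with the same number of token embeddings.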
Then we can retrieve the top-k documents for a given set of queries:
queries_embeddings = model.encode(
["query for document 3", "query for document 1"],
batch_size=32,
is_query=True, # Encoding queries
show_progress_bar=True,
)
scores = retriever.retrieve(
queries_embeddings=queries_embeddings,
k=10,
)
print(scores)
Sample Output:
[
[
{"id": "3", "score": 11.266985893249512},
{"id": "1", "score": 10.303335189819336},
{"id": "2", "score": 9.502392768859863},
],
[
{"id": "1", "score": 10.88800048828125},
{"id": "3", "score": 9.950843811035156},
{"id": "2", "score": 9.602447509765625},
],
]
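Retrieval with a token-level index is usually done in two stages: each query token fetches its nearest document tokens to build a small candidate set, and only those candidates are rescored with the full MaxSim late-interaction score. The sketch below assumes PyLate's retriever follows a similar candidate-then-rerank scheme; it uses brute-force NumPy search in place of Voyager's approximate (HNSW) search, and retrieve and n_per_token are hypothetical names. It returns results in the same {"id", "score"} shape as the sample output above.

```python
import numpy as np


def maxsim(query: np.ndarray, document: np.ndarray) -> float:
    """Sum over query tokens of the best dot product with any document token."""
    return float((query @ document.T).max(axis=1).sum())


def retrieve(queries_embeddings, documents, k=10, n_per_token=2):
    """documents maps a document id to its (num_tokens, dim) embedding matrix."""
    ids = list(documents)
    # Flatten all document token vectors, remembering which document owns each row
    token_matrix = np.vstack([documents[i] for i in ids])
    owners = [i for i in ids for _ in range(len(documents[i]))]
    results = []
    for query in queries_embeddings:
        # Stage 1: each query token nominates the owners of its nearest tokens
        sims = query @ token_matrix.T  # (query_tokens, total_document_tokens)
        nearest = np.argsort(-sims, axis=1)[:, :n_per_token]
        candidates = {owners[j] for j in nearest.ravel()}
        # Stage 2: rescore only the candidates with the exact MaxSim score
        scored = [{"id": i, "score": maxsim(query, documents[i])} for i in candidates]
        scored.sort(key=lambda r: r["score"], reverse=True)
        results.append(scored[:k])
    return results
```

Documents that no query token reaches in stage 1 are never scored, which is what keeps this approach cheap on large collections.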
If you only want to use the ColBERT model to rerank the output of your first-stage retrieval pipeline without building an index, you can use the rank.rerank function and pass it the queries and documents to rerank:
from pylate import rank
queries = [
"query A",
"query B",
]
documents = [
["document A", "document B"],
["document 1", "document C", "document B"],
]
documents_ids = [
[1, 2],
[1, 3, 2],
]
queries_embeddings = model.encode(
queries,
is_query=True,
)
documents_embeddings = model.encode(
documents,
is_query=False,
)
reranked_documents = rank.rerank(
documents_ids=documents_ids,
queries_embeddings=queries_embeddings,
documents_embeddings=documents_embeddings,
)
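Reranking skips the index entirely: each query is scored against its own candidate list with MaxSim (for every query token, take its highest similarity with any document token, then sum over query tokens), and the candidates are sorted by that score. The NumPy sketch below illustrates that computation; it is not PyLate's rank.rerank implementation, and we assume an output shape similar to the retriever's list of {"id", "score"} dictionaries.

```python
import numpy as np


def maxsim(query: np.ndarray, document: np.ndarray) -> float:
    """Sum over query tokens of the best dot product with any document token."""
    return float((query @ document.T).max(axis=1).sum())


def rerank(documents_ids, queries_embeddings, documents_embeddings):
    """Sort each query's own candidates by MaxSim score, highest first."""
    results = []
    for ids, query, docs in zip(documents_ids, queries_embeddings, documents_embeddings):
        scored = [{"id": doc_id, "score": maxsim(query, doc)} for doc_id, doc in zip(ids, docs)]
        results.append(sorted(scored, key=lambda r: r["score"], reverse=True))
    return results
```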
We welcome contributions! To get started:
pip install "pylate[dev]"
make test
make ruff
make livedoc
You can cite the library with the following BibTeX entry:
@misc{PyLate,
title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
author={Chaffin, Antoine and Sourty, Raphaël},
url={https://github.com/lightonai/pylate},
year={2024}
}