UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0

[Question] Hard negative mining #2697

Open · austinmw opened this issue 1 month ago

austinmw commented 1 month ago

Hi, does sentence-transformers happen to have any utility methods to generate an expanded dataset with hard negatives from an input dataset and model?

tomaarsen commented 1 month ago

Hello!

It does not currently, although this would be a very valuable addition! I'd be very happy to receive a pull request for this. In the meantime, you can use the semantic_search utility function to get the top K matching entries from your corpus for each of your queries. You can then manually filter out the true positives and/or use a CrossEncoder to filter away the false negatives, so that you're left with only hard negatives, i.e. samples that appear similar to the query but should not actually be retrieved for it.
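
For reference, here's a rough, untested sketch of that flow. The model names are just common examples, and the 0.5 cutoff is an arbitrary choice; anything the CrossEncoder scores above it is treated as a false negative and discarded:

from sentence_transformers import SentenceTransformer, CrossEncoder, util

model = SentenceTransformer("all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

queries = ["What is the capital of France?"]
corpus = ["Paris is the capital of France.", "Berlin is the capital of Germany."]

query_embeddings = model.encode(queries, convert_to_tensor=True)
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

# Top-k nearest corpus entries per query, sorted by descending similarity
hits_per_query = util.semantic_search(query_embeddings, corpus_embeddings, top_k=10)

for query, hits in zip(queries, hits_per_query):
    # Drop the top hit (usually the true positive), then re-score the rest
    candidates = [corpus[hit["corpus_id"]] for hit in hits[1:]]
    scores = cross_encoder.predict([(query, candidate) for candidate in candidates])
    # Keep only candidates that the CrossEncoder also scores as dissimilar
    hard_negatives = [c for c, s in zip(candidates, scores) if s < 0.5]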

Please let me know if that sounds feasible!

austinmw commented 1 month ago

Hi @tomaarsen, thanks for your reply!

I could probably submit a PR for this. Can you give me some initial feedback on the following before I submit a draft?

import random

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from datasets import Dataset
from tqdm.auto import tqdm

# Synthetic data for testing
def create_sample_pairs_dataset():
    data = {
        'anchor': [
            "What are the health benefits of regular exercise?",
            "How to improve your time management skills?",
            "What is the capital of France?",
            "How does photosynthesis work in plants?",
            "What are the symptoms of a common cold?",
            "How to cook a perfect steak?",
            "What is the importance of cybersecurity?",
            "How to start investing in stocks?",
            "What are the best practices for remote work?",
            "How does blockchain technology work?",
            "What are the stages of the water cycle?",
            "How to learn a new language quickly?",
            "What are the effects of climate change?",
            "How to maintain a healthy work-life balance?",
            "What are the benefits of meditation?",
            "How to build a successful startup?"
        ],
        'positive': [
            "Regular exercise improves cardiovascular health, strengthens muscles, and enhances mental well-being.",
            "Improving time management skills involves setting priorities, using tools like calendars and planners, and avoiding procrastination.",
            "The capital of France is Paris, a major European city known for its art, fashion, and culture.",
            "Photosynthesis is the process by which plants convert sunlight into chemical energy, producing oxygen as a byproduct.",
            "Common cold symptoms include a runny nose, sore throat, coughing, sneezing, and congestion.",
            "To cook a perfect steak, season it well, use a hot pan, and let it rest before serving.",
            "Cybersecurity is crucial for protecting sensitive information from cyber threats and maintaining privacy.",
            "Starting to invest in stocks requires understanding the market, researching companies, and considering risks.",
            "Best practices for remote work include setting up a dedicated workspace, maintaining regular hours, and communicating effectively.",
            "Blockchain technology is a decentralized ledger system that ensures secure and transparent transactions.",
            "The stages of the water cycle include evaporation, condensation, precipitation, and collection.",
            "Learning a new language quickly involves consistent practice, immersion, and using language learning apps.",
            "Climate change effects include rising temperatures, melting ice caps, and increased frequency of extreme weather events.",
            "Maintaining a healthy work-life balance requires setting boundaries, prioritizing self-care, and managing time effectively.",
            "Meditation benefits include reduced stress, improved concentration, and enhanced emotional health.",
            "Building a successful startup involves identifying a market need, creating a solid business plan, and securing funding."
        ]
    }
    df = pd.DataFrame(data)
    dataset = Dataset.from_pandas(df)
    return dataset

def add_hard_negatives(dataset, embedding_model_name, cross_encoder_name, range_min=1, threshold=0.5, batch_size=8, use_gpu=True, negative_number=3):
    """
    Add hard negatives to a dataset of (anchor, positive) pairs to create (anchor, positive, negative) triplets.

    Args:
        dataset (Dataset): The dataset containing (anchor, positive) pairs.
        embedding_model_name (str): Name of the embedding model to use.
        cross_encoder_name (str): Name of the cross encoder model to use.
        range_min (int): Minimum rank of the closest matches to consider as negatives (e.g., if 2, the 2 closest remaining matches are skipped).
        threshold (float): Threshold for CrossEncoder similarity score.
        batch_size (int): Batch size for processing.
        use_gpu (bool): Whether to use GPU for searching.
        negative_number (int): Number of negatives to sample.

    Returns:
        Dataset: A dataset containing (anchor, positive, negative) triplets.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
    model = SentenceTransformer(embedding_model_name, device=device)
    cross_encoder = CrossEncoder(cross_encoder_name, device=device)

    # Retrieve enough candidates to cover the skipped top ranks, the true positive, and the requested negatives
    k = negative_number + range_min + 1

    # Use the positive sentences as the corpus to mine negatives from
    anchors = dataset['anchor']
    positives = dataset['positive']
    sentences = positives
    embeddings = model.encode(sentences, convert_to_tensor=True, device=device)

    # Find the top-k corpus matches for each anchor, processing anchors in batches
    triplets_data = []
    for start_idx in tqdm(range(0, len(anchors), batch_size), desc="Batches"):
        end_idx = min(start_idx + batch_size, len(anchors))
        batch_embeddings = model.encode(anchors[start_idx:end_idx], convert_to_tensor=True, device=device)
        for idx, query_embedding in enumerate(batch_embeddings):
            hits = util.semantic_search(query_embedding, embeddings, top_k=k)[0]

            # Filter out the true positives
            true_positive_idx = start_idx + idx
            hits = [hit for hit in hits if sentences[hit['corpus_id']] != positives[true_positive_idx]]

            if not hits:
                continue  # Skip if no potential negatives are found

            # Use CrossEncoder to filter false negatives
            cross_encoder_scores = cross_encoder.predict([[anchors[start_idx + idx], sentences[hit['corpus_id']]] for hit in hits])

            # Keep only hits the CrossEncoder scores below the threshold, i.e. likely true negatives
            filtered_hits = [hit for hit, score in zip(hits, cross_encoder_scores) if score < threshold]

            # Skip the `range_min` closest remaining matches, then sample negatives without replacement
            filtered_hits = filtered_hits[range_min:]
            if len(filtered_hits) > negative_number:
                filtered_hits = random.sample(filtered_hits, negative_number)

            if len(filtered_hits) == 0:
                continue  # Skip if no hard negatives found

            # Create triplets (anchor, positive, negative)
            positive = positives[true_positive_idx]
            for hit in filtered_hits:
                negative = sentences[hit['corpus_id']]
                triplets_data.append({
                    'anchor': anchors[start_idx + idx],
                    'positive': positive,
                    'negative': negative
                })

    if len(triplets_data) == 0:
        raise ValueError("No triplets were generated. Please check the parameters and dataset.")

    triplets_dataset = Dataset.from_pandas(pd.DataFrame(triplets_data))

    return triplets_dataset

# Example usage
if __name__ == "__main__":
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    threshold = 0.5  # Default threshold, can be adjusted
    batch_size = 8  # Batch size for processing
    range_min = 2  # Minimum rank of the closest matches to consider as negatives
    use_gpu = True  # Use GPU by default
    negative_number = 3  # Number of negatives to sample

    # Create sample pairs dataset
    sample_pairs_dataset = create_sample_pairs_dataset()

    # Generate hard negatives
    hard_negative_dataset = add_hard_negatives(
        dataset=sample_pairs_dataset,
        embedding_model_name=embedding_model_name,
        cross_encoder_name=cross_encoder_name,
        range_min=range_min,
        threshold=threshold,
        batch_size=batch_size,
        use_gpu=use_gpu,
        negative_number=negative_number
    )
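
    # Inspect the resulting (anchor, positive, negative) triplets
    print(hard_negative_dataset)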

tomaarsen commented 1 month ago

@austinmw This is a great starting point! I actually have quite a lot of ideas for expanding this helper function. Would you be okay with me taking over development based on this and opening a PR myself?

jturner116 commented 1 month ago

I was looking into this to reproduce something similar for the TREC RAG version of MSMARCO; this functionality would be very much appreciated :D

austinmw commented 1 month ago

@tomaarsen absolutely, thanks!