facebookresearch / MetaCLIP

ICLR2024 Spotlight: curation/training code, metadata, distribution and pre-trained models for MetaCLIP; CVPR 2024: MoDE: CLIP Data Experts via Clustering

Optimize substr_matching with the Aho-Corasick algorithm #45

Closed andimarafioti closed 5 months ago

andimarafioti commented 6 months ago

I optimized substr_matching by replacing the brute-force loop with the Aho-Corasick algorithm.

The Aho-Corasick algorithm is a string-searching algorithm invented by Alfred V. Aho and Margaret J. Corasick in 1975. It is a dictionary-matching algorithm that locates elements of a finite set of strings (the "dictionary") within an input text, matching all dictionary strings simultaneously in a single pass over the text.
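For intuition, here is a minimal standalone sketch of pyahocorasick's dictionary matching (the dictionary entries are made up for illustration and are not from MetaCLIP's metadata):

import ahocorasick

# Build the automaton once from the dictionary entries.
automaton = ahocorasick.Automaton()
for idx, word in enumerate([" cat ", " dog ", " cat food "]):
    automaton.add_word(word, (idx, word))
automaton.make_automaton()

# A single pass over the text reports every dictionary entry found in it.
for end_index, (entry_id, entry) in automaton.iter(" my cat likes cat food "):
    print(entry_id, repr(entry))  # prints entry 0 (" cat ", twice) and entry 2 (" cat food ")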

import json
import timeit
import random
import ahocorasick
from metaclip.substr_matching import spacing

with open("metadata.json") as f:
    metadata = json.load(f)

# The automaton and the space-padded metadata are built once and cached in
# module-level globals, mirroring the lazy initialization of the original code.
automaton = None
spaced_metadata = None

def initialize_automaton(spaced_metadata):
    # Build the Aho-Corasick automaton over all space-padded metadata entries.
    automaton = ahocorasick.Automaton()
    for idx, key in enumerate(spaced_metadata):
        automaton.add_word(key, (idx, key))
    automaton.make_automaton()
    return automaton

def optimized_substr_matching(text, metadata):
    global spaced_metadata, automaton
    if spaced_metadata is None:
        spaced_metadata = [f" {entry} " for entry in metadata]
    text = spacing(text)
    if automaton is None:
        automaton = initialize_automaton(spaced_metadata)
    # One pass over the text reports every metadata entry it contains.
    matched_entry_ids = set()
    for end_index, (entry_id, original_value) in automaton.iter(text):
        matched_entry_ids.add(entry_id)
    return list(matched_entry_ids)

def original_substr_matching(text, metadata):
    global spaced_metadata
    if spaced_metadata is None:
        spaced_metadata = []
        for entry in metadata:
            spaced_metadata.append(f" {entry} ")
    text = spacing(text)
    matched_entry_ids = []
    for entry_id, entry in enumerate(spaced_metadata):
        if entry in text:
            matched_entry_ids.append(entry_id)
    return matched_entry_ids

def process_texts_original(texts, metadata):
    return [original_substr_matching(text, metadata) for text in texts]

def process_texts_optimized(texts, metadata):
    return [optimized_substr_matching(text, metadata) for text in texts]

# Generate a sample text by joining random metadata entries
sample_text = " ".join(random.choices(metadata, k=10))

# Time the original and optimized functions
original_time = timeit.timeit(lambda: original_substr_matching(sample_text, metadata), number=1)
warm_up_optimized_time = timeit.timeit(lambda: optimized_substr_matching(sample_text, metadata), number=1)
optimized_time = timeit.timeit(lambda: optimized_substr_matching(sample_text, metadata), number=1)

print(f"Original method time: {original_time:.6f} seconds")
print(f"Warm up optimized method time: {warm_up_optimized_time:.6f} seconds")
print(f"Optimized method time: {optimized_time:.6f} seconds")

is_consistent = sorted(original_substr_matching(sample_text, metadata)) == sorted(optimized_substr_matching(sample_text, metadata)) 
print(f"Is the new implementation consistent with the old one?: {is_consistent}")

# Generate a list of short sentences as sample texts
sample_texts = [" ".join(random.choices(metadata, k=10)) for _ in range(1000)]  # 1000 short sentences

# Time the processing of a list of sentences
original_time = timeit.timeit(lambda: process_texts_original(sample_texts, metadata), number=1)
optimized_time = timeit.timeit(lambda: process_texts_optimized(sample_texts, metadata), number=1)

print(f"Original method time for 1000 texts: {original_time:.6f} seconds")
print(f"Optimized method time for 1000 texts: {optimized_time:.6f} seconds")

Which prints on my laptop:

Original method time: 0.086521 seconds
Warm up optimized method time: 0.662693 seconds
Optimized method time: 0.000024 seconds
Is the new implementation consistent with the old one?: True
Original method time for 1000 texts: 51.219827 seconds
Optimized method time for 1000 texts: 0.015514 seconds

Please do note that the output of the new substr_matching is no longer ordered. If you need it to be ordered, you can change the last line to sorted(list(matched_entry_ids)).

I appreciate that you want to reduce dependencies for the repo, but considering that the use case is building massive datasets, this seems like a sensible improvement to me. What do you think? Is there some context I'm missing?

howardhsu commented 5 months ago

that's a great idea to improve the existing brute-force matching, thx a lot for the contribution. will refactor slightly later.
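A possible shape for that refactor, as a rough sketch only (it assumes the same lazy-initialization pattern as above and the existing spacing() helper; it is not the actual committed change):

import ahocorasick
from metaclip.substr_matching import spacing  # in an in-place refactor, spacing() is defined in the same file

spaced_metadata, automaton = None, None

def substr_matching(text, metadata):
    global spaced_metadata, automaton
    if automaton is None:
        # Build the automaton once over space-padded metadata entries.
        spaced_metadata = [f" {entry} " for entry in metadata]
        automaton = ahocorasick.Automaton()
        for entry_id, entry in enumerate(spaced_metadata):
            automaton.add_word(entry, entry_id)
        automaton.make_automaton()
    text = spacing(text)
    # sorted() keeps the ascending-id order the brute-force version produced.
    return sorted({entry_id for _, entry_id in automaton.iter(text)})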