huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0
8.93k stars 777 forks

Memory Leak in encode_batch Function #1421

Closed Atakey closed 7 months ago

Atakey commented 9 months ago

Problem

I have encountered a memory leak issue in the encode_batch function: https://github.com/huggingface/tokenizers/blob/11462596d11d886e501091e28acbe1174385087a/bindings/python/src/tokenizer.rs#L1015 The leak becomes more apparent as the number of text inputs increases, which suggests that its size correlates with the number of texts processed.

This may not strictly be a memory leak, but it is a problem nonetheless.

Steps to Reproduce

  1. Prepare a large dataset of text inputs; the larger its total size in memory, the better. For example, 20 million sentences is enough.
  2. Use the encode_batch function to process this dataset.
  3. Observe the memory usage. With encode_batch(batch_sentences), memory usage increases noticeably; this does not happen with encode_batch(["".join(sentence) for sentence in batch_sentences]).
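One way to make step 3 concrete is to record the process's peak resident memory after each batch. The sketch below is only an illustration: the helper names `peak_rss_mib` and `track_peak_rss` are hypothetical, it relies on the standard-library `resource` module (Unix only), and it uses a stand-in workload rather than a real tokenizer.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Return the peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def track_peak_rss(process_batch, batches):
    """Call process_batch on every batch and record peak RSS after each call."""
    readings = []
    for batch in batches:
        process_batch(batch)
        readings.append(peak_rss_mib())
    return readings


# Stand-in workload (an assumption for illustration); pass tokenizer.encode_batch
# as process_batch to inspect the real case.
demo_batches = [["some example text"] * 1000 for _ in range(5)]
demo_readings = track_peak_rss(lambda batch: [s.split() for s in batch], demo_batches)
print([f"{r:.1f} MiB" for r in demo_readings])
```

Because `ru_maxrss` is a high-water mark, the readings never decrease. A series that keeps climbing across batches indicates growing peak usage, while a flat series indicates that memory is being reused.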

Simple code to reproduce

import math
import os
import time
from itertools import islice
from typing import List

import tqdm
from tokenizers import Tokenizer as TokenizerFast

def copy_batch_sentences(batch_sentences: List[str]):
    return ["".join(sentence) for sentence in batch_sentences]

def lazy_groups_of(iterable, group_size: int):
    """
    Takes an iterable and batches the individual instances into lists of the
    specified size. The last list may be smaller if there are instances left over.
    """
    iterator = iter(iterable)
    while True:
        s = list(islice(iterator, group_size))
        if len(s) > 0:
            yield s
        else:
            break

def load_sentences_from_file(line_num: int = 20_000_000):
    pass

def memory_leak_test_example():
    batch_size = 20000
    sentences = load_sentences_from_file()
    # this will cause memory leak
    for batch_sentences in tqdm.tqdm(lazy_groups_of(sentences, batch_size),
                                total=math.ceil(len(sentences) / batch_size)):
        tokenizer.encode_batch(
            batch_sentences ,
        )

def memory_not_leak_test_example():
    batch_size = 20000
    sentences = load_sentences_from_file()
    for batch_sentences in tqdm.tqdm(lazy_groups_of(sentences, batch_size),
                                total=math.ceil(len(sentences) / batch_size)):
        tokenizer.encode_batch(
            copy_batch_sentences(batch_sentences),
        )

Potential Cause

It appears that the issue might be related to how the memory addresses of text data are managed at the interface between Python and Rust. When batch_sentences is passed directly to encode_batch, the Rust functions seem to receive Python-allocated memory addresses directly, which they might not manage correctly. In contrast, ["".join(sentence) for sentence in batch_sentences] creates new string objects in Python, which are then passed to Rust, possibly allowing Rust to manage these memory locations more effectively.
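The only part of this hypothesis that can be checked from pure Python is that the "".join(...) pattern used in copy_batch_sentences produces a new string object. The minimal sketch below shows that in CPython for a multi-character string. It says nothing about how the Rust binding handles the memory, and the variable names are illustrative only.

```python
original = "here is an example"
# Same pattern as copy_batch_sentences: join the characters back into a string.
copied = "".join(original)

print(copied == original)  # True: the contents are identical
print(copied is original)  # False in CPython: a separate string object was built
```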

Suggested Solution

A possible solution could involve ensuring that Rust takes full control of the memory management of the strings it processes. This might require modifying the encode_batch function to deep copy the string data from Python before processing, allowing Rust to handle memory allocation and deallocation independently.

Note: The above part of the content was generated by GPT-4.

ArthurZucker commented 9 months ago

Hey @Atakey, when you say that the above part is generated by gpt-4, which part do you refer to? This otherwise sounds interesting. Could you share a full reproducer (in this case the sentences are not really created)?

Atakey commented 9 months ago

> Hey @Atakey, when you say that the above part is generated by gpt-4, which part do you refer to? This otherwise sounds interesting. Could you share a full reproducer (in this case the sentences are not really created)?

I found this issue in encode_batch, but since I am not familiar with the Rust language, I discussed it with GPT-4 and had it help me write up the issue.

You can generate the sentences randomly; you shouldn't use something like ['Here is an example.'] * 20_000_000. Note that random generation will be slower than loading sentences from a file.

def generate_random_sentences(size: int, length: int = 10, word_length: int = 3):
    """
    Generates a specified number of random sentences.
    Parameters:
        size (int): The number of random sentences
        length (int): The number of words in each sentence, defaults to 10
        word_length (int): The number of characters in each word, defaults to 3
    Returns: A list of the generated random sentences
    """
    import numpy as np
    # Draw random lowercase ASCII codes (97..122, i.e. 'a'..'z')
    sentences_chars = np.random.randint(97, 123, (size, length, word_length))
    # Join characters into words, then words into space-separated sentences
    return [" ".join("".join(chr(char) for char in chars) for chars in sentence_chars)
            for sentence_chars in sentences_chars]
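A likely reason to avoid the list-multiplication shortcut mentioned above (an interpretation, not something stated in this thread) is that it fills every slot with a reference to one and the same string object, so it does not resemble a dataset of millions of separately allocated strings. The sketch below illustrates the difference using only the standard library; `random_sentence` is a hypothetical stand-in for the NumPy-based generator.

```python
import random
import string

# List multiplication repeats a reference to a single str object.
repeated = ["Here is an example."] * 5
print(len({id(s) for s in repeated}))  # 1


def random_sentence(rng: random.Random, length: int = 10, word_length: int = 3) -> str:
    """Build one sentence of `length` random lowercase words."""
    words = ("".join(rng.choices(string.ascii_lowercase, k=word_length)) for _ in range(length))
    return " ".join(words)


# Randomly generated sentences are each allocated as their own object.
rng = random.Random(0)
distinct = [random_sentence(rng) for _ in range(5)]
print(len({id(s) for s in distinct}))  # 5
```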

github-actions[bot] commented 8 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

platers commented 6 months ago

I also ran into this. When batching more than one input, memory leaks. I worked around it by using multiprocessing to call many instances of the tokenizer on single inputs, which isn't much slower than the intended batching.

Narsil commented 6 months ago

I'm confused. The example doesn't seem to show any memory leak issue; in fact, the "leak" version uses less memory.

Also please complete the script all the way to make it reproducible.