huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

Training a tokenizer with limited memory #1460

Closed arxyzan closed 5 months ago

arxyzan commented 7 months ago

Hi, I'm trying to train a new tokenizer from a Llama fast tokenizer, following the instructions at https://huggingface.co/learn/nlp-course/chapter6/2#training-a-new-tokenizer. Even with batch iteration, I get an OOM error and the kernel crashes. I've seen other people report the same problem, and none of them found a workaround.

My workaround is to train the tokenizer in multiple steps on different shards of the dataset. (I'm not sure whether this results in the same tokenizer as training in one pass!)

Reproducible Code

The reproducible code is as follows:

from datasets import load_dataset
from transformers import AutoTokenizer

TOKENIZER_PATH = "meta-llama/Llama-2-7b-hf"
VOCAB_SIZE = 42000
BATCH_SIZE = 1000
SHARD_SIZE = 100_000  # Train the tokenizer on SHARD_SIZE samples in each pass

# Load dataset to calculate `num_shards`
raw_dataset = load_dataset("wikimedia/wikipedia", "20231101.fa", split="train")
num_shards = len(raw_dataset) // SHARD_SIZE
start_index = 0

for i in range(num_shards):
    dataset = load_dataset(
        "wikimedia/wikipedia",
        "20231101.fa",
        split=f"train[{start_index}:{start_index + SHARD_SIZE}]",
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

    def get_training_corpus(d, batch_size=BATCH_SIZE):
        for batch in d.iter(batch_size=batch_size):
            yield batch["text"]

    data_iterator = get_training_corpus(dataset)

    tokenizer = tokenizer.train_new_from_iterator(
        data_iterator,
        vocab_size=VOCAB_SIZE,
        length=len(dataset),
    )

    start_index += SHARD_SIZE

    tokenizer.save_pretrained(TOKENIZER_PATH)
    print(f"Tokenizer trained for data shard #{i} and saved to `{TOKENIZER_PATH}`")

    # Delete objects to free up RAM
    del dataset
    del tokenizer  # Does this also delete the object within the Rust backend?
    del data_iterator

Problem

The problem is that RAM allocation still increases gradually and, depending on the full dataset size, an OOM error can still happen. The object deletion at the end of each loop iteration was meant to free memory on each pass, but it seems to have no effect, since the tokenizer object and trainer reside in the Rust backend. I think that if there were a way to also delete objects in the Rust backend from Python code, the problem would no longer arise. Or maybe there is another workaround that I don't know of!
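One way to check whether memory really grows from shard to shard is to log the process's resident set size after each pass. The following is only a minimal sketch: the helper names `current_rss_mb` and `log_memory` are hypothetical, and it assumes a Linux system where `/proc/self/status` is available (elsewhere it reports nothing).

```python
import gc
from pathlib import Path


def current_rss_mb():
    """Return this process's resident set size in MiB, or None if unavailable."""
    status = Path("/proc/self/status")  # Linux-only assumption
    if not status.exists():
        return None
    for line in status.read_text().splitlines():
        if line.startswith("VmRSS:"):
            # The line looks like "VmRSS:    123456 kB"
            return int(line.split()[1]) / 1024
    return None


def log_memory(label):
    """Run the garbage collector, then print and return the current RSS."""
    gc.collect()
    rss = current_rss_mb()
    if rss is None:
        print(f"{label}: RSS not available on this platform")
    else:
        print(f"{label}: RSS = {rss:.1f} MiB")
    return rss


log_memory("after shard")
```

Calling `log_memory(f"shard {i}")` at the end of each loop iteration, after the `del` statements, would show whether the numbers keep climbing or merely differ between shards.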

Narsil commented 6 months ago

Python is already able to clean things up on its own, and yes, the Rust backend also cleans up after itself (unless there's a bug).

What's more likely is that the various datasets contain different data, namely different sentence lengths, which trigger different kinds of memory usage.
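A quick way to test this hypothesis is to compare text-length statistics across shards. The sketch below is only an illustration: `text_length_stats` is a hypothetical helper, and it assumes each shard can be iterated as plain strings (for example `dataset["text"]` in the reproduction code above).

```python
def text_length_stats(texts):
    """Compute simple length statistics over an iterable of strings.

    Streams through the input once, so the texts do not all need to be
    held in memory at the same time.
    """
    count = 0
    total_chars = 0
    max_chars = 0
    for text in texts:
        n = len(text)
        count += 1
        total_chars += n
        max_chars = max(max_chars, n)
    mean_chars = total_chars / count if count else 0.0
    return {
        "count": count,
        "total_chars": total_chars,
        "mean_chars": mean_chars,
        "max_chars": max_chars,
    }


# Toy shards standing in for real dataset slices
short_shard = ["ab", "abcd"]
long_shard = ["a" * 1000, "b" * 3000]
print(text_length_stats(short_shard))
print(text_length_stats(long_shard))
```

If shards with a much larger total or maximum length line up with the memory spikes, that would support the explanation above rather than a leak.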

arxyzan commented 6 months ago

Thanks @Narsil. Another question: do the following scenarios result in the same tokenizer?

  1. Training the tokenizer in one pass on the whole data (assuming there's enough memory)
  2. Training the tokenizer in multiple passes on separate shards of the dataset (splitting the dataset into N portions and training the tokenizer on these shards one after the other)

Thanks in advance!
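The two scenarios can be compared on a toy example. The sketch below is not the library's BPE trainer. It assumes only that a BPE-style trainer picks its first merge as the most frequent adjacent symbol pair, and that each training call starts from scratch on the data it is given (which is how the documentation of `train_new_from_iterator` reads, though that is not confirmed in this thread). The function names are hypothetical.

```python
from collections import Counter


def pair_counts(words):
    """Count adjacent character pairs across all words."""
    counts = Counter()
    for word in words:
        for a, b in zip(word, word[1:]):
            counts[(a, b)] += 1
    return counts


def first_merge(words):
    """Return the most frequent adjacent pair (ties broken by pair order)."""
    counts = pair_counts(words)
    return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]


shard_a = ["abab", "abab", "abab"]
shard_b = ["cdcd", "ab"]

# Scenario 1: one pass over the whole corpus
print(first_merge(shard_a + shard_b))  # ('a', 'b')

# Scenario 2: the final pass sees only the last shard
print(first_merge(shard_b))  # ('c', 'd')
```

Under these assumptions, the two scenarios need not agree: the pair statistics of the last shard alone can differ from those of the full corpus, so the learned merges can differ as well.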

github-actions[bot] commented 5 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.