huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

get_wikitext2 has bug #2020

Open alex-ber opened 1 month ago

alex-ber commented 1 month ago

System Info

optimum version 1.21.4 (latest)
# Use the official Python image from the Docker Hub
FROM public.ecr.aws/docker/library/python:3.10-slim

Who can help?

No response

Reproduction (minimal, reproducible, runnable)

from optimum.gptq.data import get_wikitext2
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
get_wikitext2(tokenizer=tokenizer, nsamples=128, seqlen=32, split="train")

This produces the following warning:

Token indices sequence length is longer than the specified maximum sequence length for this model (73218 > 2048). Running this sequence through the model will result in indexing errors
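The tokenizer emits this message when an encoding is longer than `tokenizer.model_max_length` (2048 for this model) and no `max_length` was passed. It is a log message, not an exception. Below is a minimal sketch of that condition using a hypothetical helper `would_warn`, which is not a transformers API and ignores other details of the real implementation:

```python
def would_warn(num_tokens: int, model_max_length: int, max_length=None) -> bool:
    """Hypothetical approximation of the tokenizer's length check.

    The message fires only when no explicit max_length was requested and the
    encoded sequence is longer than model_max_length. Without truncation,
    nothing is cut off and no error is raised at tokenization time.
    """
    return max_length is None and num_tokens > model_max_length

# Numbers from the warning above: the whole encoded text vs. the model limit.
print(would_warn(73218, 2048))  # True
# A single seqlen=32 sample is far below the limit.
print(would_warn(32, 2048))     # False
```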

Expected behavior

This is the proposed fix:


import random
from typing import Any

import torch
from datasets import load_dataset


def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
    if split == "train":
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    elif split == "validation":
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    else:
        raise ValueError(f"split must be 'train' or 'validation', got {split!r}")
    ## length of 288059 should be enough
    #text = "".join([" \n" if s == "" else s for s in data["text"][:1000]])

    dataset = []
    for _ in range(nsamples):
        # Draw random rows until one has at least `seqlen` tokens.
        while True:
            i = random.randint(0, len(data) - 1)
            text = data[i]["text"]
            if len(tokenizer.tokenize(text)) >= seqlen:
                enc = tokenizer(text, return_tensors="pt")
                break
        # Take a random window of exactly `seqlen` tokens from that row.
        # randint is inclusive, so this also works when the row has exactly seqlen tokens.
        i = random.randint(0, enc.input_ids.shape[1] - seqlen)
        j = i + seqlen
        inp = enc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        dataset.append({"input_ids": inp, "attention_mask": attention_mask})
    return dataset

Inspired by `get_c4` and `get_c4_new`.

No warning is produced.
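The idea behind the proposed fix is rejection sampling over individual rows followed by a random window inside the chosen row, and it can be illustrated without `datasets` or `torch`. This minimal sketch assumes a toy whitespace tokenizer and in-memory rows. `sample_windows` and `toy_tokenize` are hypothetical names, and the up-front check that at least one row is long enough is an addition that the proposed code does not have:

```python
import random
from typing import Callable, List


def sample_windows(rows: List[str], tokenize: Callable[[str], List[int]],
                   seqlen: int, nsamples: int, seed: int = 0) -> List[List[int]]:
    """Per-row rejection sampling, mirroring the structure of the proposed fix."""
    rng = random.Random(seed)
    # Added guard (not in the proposed code): avoid an endless loop.
    if not any(len(tokenize(r)) >= seqlen for r in rows):
        raise ValueError("no row has at least seqlen tokens")
    samples = []
    for _ in range(nsamples):
        # Keep drawing random rows until one is long enough.
        while True:
            ids = tokenize(rows[rng.randint(0, len(rows) - 1)])
            if len(ids) >= seqlen:
                break
        # randint is inclusive: start can be 0 .. len(ids) - seqlen.
        start = rng.randint(0, len(ids) - seqlen)
        samples.append(ids[start:start + seqlen])
    return samples


# Hypothetical toy tokenizer: one id per whitespace-separated word.
vocab = {}


def toy_tokenize(text: str) -> List[int]:
    return [vocab.setdefault(w, len(vocab)) for w in text.split()]


rows = ["", " = Valkyria Chronicles III = ", "word " * 50, "short line"]
windows = sample_windows(rows, toy_tokenize, seqlen=8, nsamples=5)
print([len(w) for w in windows])  # [8, 8, 8, 8, 8]
```

Each row is encoded on its own, so no single encoding is longer than the longest row. That is why the warning disappears, provided no individual row exceeds the model's maximum length.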

IlyasMoutawwakil commented 1 month ago

@SunMarc is there a reason why get_wikitext2 is different from the other methods?

SunMarc commented 1 month ago

Not sure. This was something TheBloke coded back then. Maybe this is because data[i]["text"] is pretty long, so it takes a while to find a text < seqlen?
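On the cost of the per-row approach: each draw is accepted with probability p, the fraction of rows that have at least `seqlen` tokens. The number of draws per sample is therefore geometric with mean 1/p. If no row qualifies, the `while True` loop in the proposed fix never terminates. This small sketch uses a hypothetical helper and made-up row lengths, not values measured on WikiText-2:

```python
from typing import Sequence


def expected_draws(row_lengths: Sequence[int], seqlen: int) -> float:
    """Expected number of random row draws per accepted sample.

    A draw is accepted with probability p = (#rows with >= seqlen tokens) / (#rows),
    so the number of draws until acceptance is geometric with mean 1 / p.
    """
    p = sum(1 for n in row_lengths if n >= seqlen) / len(row_lengths)
    if p == 0:
        return float("inf")  # rejection loop would never terminate
    return 1.0 / p


# Hypothetical token counts per row (empty rows count as 0 tokens).
lengths = [0, 12, 0, 150, 40, 0, 300, 5]
print(expected_draws(lengths, seqlen=32))    # 2.6666666666666665
print(expected_draws(lengths, seqlen=2048))  # inf
```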

Token indices sequence length is longer than the specified maximum sequence length for this model (73218 > 2048). Running this sequence through the model will result in indexing errors

This does not happen in practice, because we slice the tokenized data afterwards:

        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = enc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
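This illustrates why only the initial encoding triggers the warning. The long sequence is produced once, and every sample is a `seqlen`-long slice of it. The sketch is dependency-free and assumes plain Python lists in place of tensors. `sample_from_joined` is a hypothetical name, and the slicing bounds mirror the snippet above:

```python
import random
from typing import List


def sample_from_joined(token_ids: List[int], seqlen: int, nsamples: int,
                       seed: int = 0) -> List[List[int]]:
    """Slice random fixed-length windows out of one long tokenized sequence."""
    # randint(0, n - seqlen - 1) needs n > seqlen.
    if len(token_ids) <= seqlen:
        raise ValueError("need more than seqlen tokens")
    rng = random.Random(seed)
    samples = []
    for _ in range(nsamples):
        i = rng.randint(0, len(token_ids) - seqlen - 1)  # same bounds as above
        samples.append(token_ids[i:i + seqlen])
    return samples


# Stand-in for the 73218-token encoding from the warning (ids are just 0..N-1).
corpus_ids = list(range(73218))
windows = sample_from_joined(corpus_ids, seqlen=32, nsamples=128)
print(len(windows), {len(w) for w in windows}, max(len(w) for w in windows) <= 2048)
# 128 {32} True
```

The full sequence exceeds the 2048-token limit, but no model input built from it does.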