huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Loading from cache a dataset for LM built from a text classification dataset sometimes errors #3047

Closed: sgugger closed this issue 2 years ago

sgugger commented 3 years ago

Describe the bug

Yes, I know, that description sucks. The problem arises in the course when we build a masked language modeling dataset using the IMDB dataset. To reproduce (or at least try to, since it's a bit fickle):

Create a dataset for masked language modeling from the IMDB dataset.

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
imdb_dataset = load_dataset("imdb", split="train")

def tokenize_function(examples):
    return tokenizer(examples["text"])

# Tokenize the corpus and drop the original columns so only the tokenizer outputs remain
tokenized_dataset = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)

chunk_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

lm_dataset = tokenized_dataset.map(group_texts, batched=True)

So far, all is well. The problem comes when you re-execute that code, more specifically:

tokenized_dataset = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
lm_dataset = tokenized_dataset.map(group_texts, batched=True)

If the bug doesn't appear right away, re-run the snippet several times, or run it one line at a time, ideally in a notebook/Colab, and at some point you should get:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-40-357a56ee3d53> in <module>
----> 1 lm_dataset = tokenized_dataset.map(group_texts, batched=True)

~/git/datasets/src/datasets/arrow_dataset.py in map(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   1947                 new_fingerprint=new_fingerprint,
   1948                 disable_tqdm=disable_tqdm,
-> 1949                 desc=desc,
   1950             )
   1951         else:

~/git/datasets/src/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    424         }
    425         # apply actual function
--> 426         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    427         datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    428         # re-apply format to the output

~/git/datasets/src/datasets/fingerprint.py in wrapper(*args, **kwargs)
    404             # Call actual function
    405 
--> 406             out = func(self, *args, **kwargs)
    407 
    408             # Update fingerprint of in-place transforms + update in-place history of transforms

~/git/datasets/src/datasets/arrow_dataset.py in _map_single(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, disable_tqdm, desc, cache_only)
   2138             if os.path.exists(cache_file_name) and load_from_cache_file:
   2139                 logger.warning("Loading cached processed dataset at %s", cache_file_name)
-> 2140                 info = self.info.copy()
   2141                 info.features = features
   2142                 return Dataset.from_file(cache_file_name, info=info, split=self.split)

~/git/datasets/src/datasets/info.py in copy(self)
    278 
    279     def copy(self) -> "DatasetInfo":
--> 280         return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
    281 
    282 

~/git/datasets/src/datasets/info.py in __init__(self, description, citation, homepage, license, features, post_processed, supervised_keys, task_templates, builder_name, config_name, version, splits, download_checksums, download_size, post_processing_size, dataset_size, size_in_bytes)

~/git/datasets/src/datasets/info.py in __post_init__(self)
    177                 for idx, template in enumerate(self.task_templates):
    178                     if isinstance(template, TextClassification):
--> 179                         labels = self.features[template.label_column].names
    180                         self.task_templates[idx] = TextClassification(
    181                             text_column=template.text_column, label_column=template.label_column, labels=labels

KeyError: 'label'

It seems that when loading from the cache, the dataset tries to rebuild some kind of text classification task template (which I imagine comes from the original dataset) and looks up a key that has since been removed.
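In the meantime, a possible stopgap (my suggestion, not part of the original report) is to bypass the cache lookup that the traceback goes through, since Dataset.map exposes a load_from_cache_file flag (visible in the signature above):

# Hedged workaround sketch: force map() to recompute instead of loading the
# cached Arrow file whose DatasetInfo still carries the TextClassification template.
tokenized_dataset = imdb_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text", "label"],
    load_from_cache_file=False,
)
lm_dataset = tokenized_dataset.map(group_texts, batched=True, load_from_cache_file=False)

This trades away caching, so it only makes sense until the fix mentioned below is available.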

lhoestq commented 2 years ago

This has been fixed in 1.15, let me know if you still have this issue
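For anyone landing here later, a quick way to check whether you are already on a fixed release (assuming the fix shipped with version 1.15 as stated above):

import datasets

# Should print 1.15.0 or later if the fix mentioned above is included
print(datasets.__version__)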