Yes, I know, that description sucks. The problem arises in the course when we build a masked language modeling dataset from the IMDB dataset. To reproduce (or at least try to, since it's a bit fickle):
Create a dataset for masked language modeling from the IMDB dataset.
from datasets import load_dataset
from transformers import AutoTokenizer

# Tokenizer for the cased DistilBERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
# Only the training split of IMDB is loaded.
imdb_dataset = load_dataset("imdb", split="train")
def tokenize_function(examples):
    """Run the module-level ``tokenizer`` on the ``"text"`` entry of ``examples``.

    Args:
        examples: Mapping that has a ``"text"`` key.

    Returns:
        Whatever ``tokenizer`` returns for that entry.
    """
    texts = examples["text"]
    return tokenizer(texts)
# Tokenize in batched mode and drop the original "text" and "label" columns.
tokenized_dataset = imdb_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text", "label"],
)
chunk_size = 128


def group_texts(examples):
    """Concatenate every column of ``examples`` and re-split it into fixed-size chunks.

    Args:
        examples: Mapping from column name to a list of lists. It must
            contain an ``"input_ids"`` column.

    Returns:
        A dict with the same keys as ``examples``, where each value is a list
        of chunks of exactly ``chunk_size`` items, plus a ``"labels"`` entry
        that is a shallow copy of the ``"input_ids"`` chunk list. Trailing
        items that do not fill a whole chunk are dropped.

    Raises:
        IndexError: If ``examples`` has no columns.
        KeyError: If ``examples`` has no ``"input_ids"`` column.
    """
    # Function-scope import so the snippet does not rely on a file-level import.
    from itertools import chain

    # Flatten each column in a single linear pass; ``sum(lists, [])`` copies
    # the accumulated list on every addition, which is quadratic.
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    # The first column's length is used as the reference length
    # (assumes all columns flatten to the same length).
    first_key = list(examples.keys())[0]
    total_length = len(concatenated_examples[first_key])
    # Drop the last chunk if it is smaller than chunk_size.
    total_length = (total_length // chunk_size) * chunk_size
    # Split every column into chunks of chunk_size items.
    chunk_starts = range(0, total_length, chunk_size)
    result = {
        k: [t[i : i + chunk_size] for i in chunk_starts]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column (outer list copied, inner chunks shared).
    result["labels"] = result["input_ids"].copy()
    return result
# Build the chunked dataset by applying group_texts in batched mode.
lm_dataset = tokenized_dataset.map(
    group_texts,
    batched=True,
)
Until now, all is well. The problem comes when you re-execute that code, more specifically:
Try several times if the bug doesn't appear instantly, or just run each line one at a time, ideally in a notebook/Colab, and at some point you should get:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-40-357a56ee3d53> in <module>
----> 1 lm_dataset = tokenized_dataset.map(group_texts, batched=True)
~/git/datasets/src/datasets/arrow_dataset.py in map(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
1947 new_fingerprint=new_fingerprint,
1948 disable_tqdm=disable_tqdm,
-> 1949 desc=desc,
1950 )
1951 else:
~/git/datasets/src/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
424 }
425 # apply actual function
--> 426 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
427 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
428 # re-apply format to the output
~/git/datasets/src/datasets/fingerprint.py in wrapper(*args, **kwargs)
404 # Call actual function
405
--> 406 out = func(self, *args, **kwargs)
407
408 # Update fingerprint of in-place transforms + update in-place history of transforms
~/git/datasets/src/datasets/arrow_dataset.py in _map_single(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, disable_tqdm, desc, cache_only)
2138 if os.path.exists(cache_file_name) and load_from_cache_file:
2139 logger.warning("Loading cached processed dataset at %s", cache_file_name)
-> 2140 info = self.info.copy()
2141 info.features = features
2142 return Dataset.from_file(cache_file_name, info=info, split=self.split)
~/git/datasets/src/datasets/info.py in copy(self)
278
279 def copy(self) -> "DatasetInfo":
--> 280 return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
281
282
~/git/datasets/src/datasets/info.py in __init__(self, description, citation, homepage, license, features, post_processed, supervised_keys, task_templates, builder_name, config_name, version, splits, download_checksums, download_size, post_processing_size, dataset_size, size_in_bytes)
~/git/datasets/src/datasets/info.py in __post_init__(self)
177 for idx, template in enumerate(self.task_templates):
178 if isinstance(template, TextClassification):
--> 179 labels = self.features[template.label_column].names
180 self.task_templates[idx] = TextClassification(
181 text_column=template.text_column, label_column=template.label_column, labels=labels
KeyError: 'label'
It seems that when loading the cache, the dataset tries to access some kind of text classification template (which I imagine comes from the original dataset) and to look at a key that has since been removed.
Describe the bug
Yes, I know, that description sucks. The problem arises in the course when we build a masked language modeling dataset from the IMDB dataset. To reproduce (or at least try to, since it's a bit fickle):
Create a dataset for masked language modeling from the IMDB dataset.
Until now, all is well. The problem comes when you re-execute that code, more specifically:
Try several times if the bug doesn't appear instantly, or just run each line one at a time, ideally in a notebook/Colab, and at some point you should get:
It seems that when loading the cache, the dataset tries to access some kind of text classification template (which I imagine comes from the original dataset) and to look at a key that has since been removed.