huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

extend the map function so it can wrap around long text that does not fit in the context window #5997

Open siddhsql opened 1 year ago

siddhsql commented 1 year ago

Feature request

I understand that datasets provides a map function, which takes a callable used to tokenize the text on which a model is trained. Frequently this text will not fit within a model's context window. In that case it would be useful to wrap the text into multiple rows, with each row fitting the model's context window. I tried to do this using the following code, which I borrowed from here:

data = data.map(lambda samples: tokenizer(samples["text"], max_length=tokenizer.model_max_length, truncation=True, stride=4, return_overflowing_tokens=True), batched=True)

but running the code gives me this error:

File "/llm/fine-tune.py", line 117, in <module>
    data = data.map(lambda samples: tokenizer(samples["text"], max_length=tokenizer.model_max_length, truncation=True, stride=4, return_overflowing_tokens=True), batched=True)
  File "/llm/.env/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 580, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/llm/.env/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 545, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/llm/.env/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3087, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/llm/.env/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3480, in _map_single
    writer.write_batch(batch)
  File "/llm/.env/lib/python3.9/site-packages/datasets/arrow_writer.py", line 556, in write_batch
    pa_table = pa.Table.from_arrays(arrays, schema=schema)
  File "pyarrow/table.pxi", line 3798, in pyarrow.lib.Table.from_arrays
  File "pyarrow/table.pxi", line 2962, in pyarrow.lib.Table.validate
  File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Column 1 named input_ids expected length 394 but got length 447

The lambda function I provided correctly chops up long text so that it wraps around (which is why 394 samples become 447 after wrapping), but the dataset's map function rejects the result.

Motivation

please see above

Your contribution

I'm afraid I don't have much knowledge to help

siddhsql commented 1 year ago

I just noticed the docs say:

If batched is True and batch_size is n > 1, then the function takes a batch of n examples as input and can return a batch with n examples, or with an arbitrary number of examples.

so maybe this is a bug then.

mariosasko commented 1 year ago

All the columns in a returned batch must be of the same length. Here the input columns still have 394 rows while the tokenizer output has 447. So one solution is dropping all the input columns:

data = data.map(lambda samples: tokenizer(samples["text"], max_length=tokenizer.model_max_length, truncation=True, stride=4, return_overflowing_tokens=True), batched=True, remove_columns=data.column_names)

Another is padding or otherwise transforming the input columns so that they match the length of the tokenizer output (447).
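
For the second option, here is a minimal sketch rather than a definitive implementation. It assumes a fast tokenizer whose output includes `overflow_to_sample_mapping` when `return_overflowing_tokens=True` (one entry per output row, giving the index of the input row it came from). The helper name `expand_input_columns` is made up for illustration and is not part of `datasets` or `transformers`.

```python
from typing import Any, Dict, List, Mapping


def expand_input_columns(
    batch: Mapping[str, List[Any]],
    encoded: Mapping[str, List[Any]],
) -> Dict[str, List[Any]]:
    """Repeat each input column so it lines up with the tokenizer's output rows.

    Assumes `encoded` contains "overflow_to_sample_mapping", a list with one
    entry per output row holding the index of the originating input row.
    """
    mapping = encoded["overflow_to_sample_mapping"]
    # Repeat every input value once per chunk produced from that row.
    out = {name: [values[i] for i in mapping] for name, values in batch.items()}
    # Add the tokenizer columns; every column now has len(mapping) rows.
    out.update(encoded)
    return out


# Hypothetical usage with the tokenizer and dataset from the issue
# (left commented out because it needs a loaded tokenizer and dataset):
#
# def tokenize_and_wrap(samples):
#     encoded = tokenizer(
#         samples["text"],
#         max_length=tokenizer.model_max_length,
#         truncation=True,
#         stride=4,
#         return_overflowing_tokens=True,
#     )
#     return expand_input_columns(samples, encoded)
#
# data = data.map(tokenize_and_wrap, batched=True)
```

Compared with `remove_columns`, this keeps the original columns (for example `text`) next to each chunk, at the cost of storing each long text once per chunk.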