huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Add batching to `IterableDataset` #7054

Closed: lappemic closed this 4 months ago

lappemic commented 4 months ago

I've taken a stab at implementing a batched `IterableDataset` as requested in issue #6279. This PR adds a new `BatchedExamplesIterable` class and a `.batch()` method to the `IterableDataset` class.

The main changes are:

  1. A new `BatchedExamplesIterable` that groups examples into batches.
  2. A `.batch()` method for `IterableDataset` to easily create batched versions (a usage sketch follows the list).
  3. Support for shuffling and sharding so it works with the PyTorch `DataLoader` and multiple workers.
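
For illustration, here's a rough usage sketch (assuming the `IterableDataset.batch(batch_size, drop_last_batch=False)` signature proposed in this PR; the exact behavior may still change during review):

```python
from datasets import Dataset

# Build a small streaming dataset for illustration.
ds = Dataset.from_dict({"id": list(range(10))}).to_iterable_dataset()

# Group examples into batches of 4; the final, smaller batch is kept
# unless drop_last_batch=True.
for batch in ds.batch(batch_size=4):
    print(batch["id"])  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```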

I'm not sure if this is exactly what you had in mind, and I also haven't fully tested it yet, so I'd really appreciate your feedback. Does this seem like it's heading in the right direction? I'm happy to make changes or explore different approaches if needed.

Pinging @lhoestq

lhoestq commented 4 months ago

Cool! Thanks for diving into it :)

Your implementation is great and indeed supports shuffling and batching; you just need to additionally account for `state_dict` (for dataset checkpointing and resuming).
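
For reference, this is the checkpointing API the new iterable needs to play well with (a sketch following the documented `IterableDataset.state_dict()` / `load_state_dict()` pattern):

```python
from datasets import Dataset

ds = Dataset.from_dict({"a": list(range(6))}).to_iterable_dataset(num_shards=3)

# Iterate and checkpoint part-way through.
for idx, example in enumerate(ds):
    if idx == 2:
        state = ds.state_dict()
        break

# Later: restore the checkpoint and resume iteration from there.
ds.load_state_dict(state)
for example in ds:
    print(example)  # continues from where the checkpoint was taken
```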

That being said, I believe the implementation can be made simpler by relying on `IterableDataset.map()`, which already implements all of this. Maybe something like:


```python
def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
    def batch_fn(unbatched: dict[str, list]) -> dict[str, list]:
        # Wrap each column's list of values in an outer list, so that
        # one mapped row holds an entire batch.
        return {k: [v] for k, v in unbatched.items()}

    return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
```

And this way there's no need to reimplement everything!
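
To make the trick concrete (a tiny illustration, not code from the PR): wrapping each column's batch in an outer list turns the whole batch into a single mapped row, so iterating the mapped dataset yields one batch at a time.

```python
# What the inner helper does to one batch with two columns:
unbatched = {"text": ["a", "b", "c"], "label": [0, 1, 0]}
batched = {k: [v] for k, v in unbatched.items()}
print(batched)  # {'text': [['a', 'b', 'c']], 'label': [[0, 1, 0]]}
```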

(My only small concern: this isn't an Arrow-optimized function, so it requires the examples to be manipulated as Python objects even if the original data is in Arrow format, e.g. when streaming Parquet files. It's not a big deal though, and we can see later if we need to optimize this.)

lappemic commented 4 months ago

Thanks a lot for the feedback @lhoestq! I definitely could have saved some time by looking into it properly first. 😅

Implemented the `.batch()` method, added a proper docstring for documentation, and added tests.

Let me know what you think and whether anything needs updating.

HuggingFaceDocBuilderDev commented 4 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

lappemic commented 4 months ago

Thanks for the feedback @lhoestq!

Applied it, referenced the `batched=True` option of the `map` function, and highlighted the difference. Hope I got this right.
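
For anyone skimming: the difference is that `map(..., batched=True)` only feeds the function batches while iteration still yields single examples, whereas `.batch()` makes iteration itself yield batches. A small sketch (assuming the `.batch()` API from this PR):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3, 4]}).to_iterable_dataset()

# map(batched=True): the function sees batches, but iterating the result
# still yields one example at a time.
doubled = ds.map(lambda batch: {"x": [v * 2 for v in batch["x"]]}, batched=True)
print(next(iter(doubled)))  # {'x': 2}

# .batch(): iterating yields whole batches, one dict of lists per batch.
print(next(iter(ds.batch(batch_size=2))))  # {'x': [1, 2]}
```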

github-actions[bot] commented 4 months ago
CI benchmark results for PyArrow==8.0.0 and PyArrow==latest: per-suite new vs. old timing tables (benchmark_array_xd, benchmark_getitem_100B, benchmark_indices_mapping, benchmark_iterating, benchmark_map_filter).