huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Add `batch` method to `Dataset` class #7064

Closed: lappemic closed this 4 months ago

lappemic commented 4 months ago

This PR introduces a new `batch` method to the `Dataset` class, aligning its functionality with the `IterableDataset.batch()` method (implemented in #7054). Like that method, the implementation reuses the existing `map` method to batch examples efficiently.
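The mechanism can be sketched without the library. A batched `map` whose function returns a single row per input batch collapses every `batch_size` rows into one row whose cells are lists. In this sketch, `map_batched` is a hypothetical toy stand-in for `Dataset.map(batched=True)` that operates on a plain dict of lists. The inner `batch_fn` illustrates the idea of wrapping each column's values in a list. The `drop_last_batch` flag mirrors the parameter of the same name on `Dataset.map`. This is a sketch under those assumptions, not the PR's actual code.

```python
from typing import Callable, Dict, List

Columns = Dict[str, List]  # column name -> list of values (a columnar table)


def map_batched(table: Columns, fn: Callable[[Columns], Columns], batch_size: int,
                drop_last_batch: bool = False) -> Columns:
    """Hypothetical toy stand-in for a batched map: feed `fn` slices of
    `batch_size` rows and concatenate the rows it returns (possibly fewer)."""
    num_rows = len(next(iter(table.values()))) if table else 0
    out: Columns = {}
    for start in range(0, num_rows, batch_size):
        stop = min(start + batch_size, num_rows)
        if drop_last_batch and stop - start < batch_size:
            break  # skip the final, incomplete batch
        chunk = {k: v[start:stop] for k, v in table.items()}
        for k, v in fn(chunk).items():
            out.setdefault(k, []).extend(v)
    return out


def batch(table: Columns, batch_size: int, drop_last_batch: bool = False) -> Columns:
    """Group every `batch_size` rows into one row whose cells are lists."""
    def batch_fn(chunk: Columns) -> Columns:
        # Wrap each column's list in another list: one output row per input batch.
        return {k: [v] for k, v in chunk.items()}
    return map_batched(table, batch_fn, batch_size, drop_last_batch)


table = {"id": list(range(10)), "text": [f"t{i}" for i in range(10)]}
print(batch(table, batch_size=4)["id"])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch(table, batch_size=4, drop_last_batch=True)["id"])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With the real library, the equivalent would presumably be `ds.batch(batch_size=4)` on a `datasets.Dataset`, assuming it takes a `batch_size` argument like `IterableDataset.batch()`. Reusing `map` means batching inherits its caching and Arrow-backed writing instead of duplicating that machinery.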

Key changes:

Closes #7063

Once the approach is approved, i will create the tests and update the documentation.

lhoestq commented 4 months ago

Looks good to me! :)

You might want to add `map`'s `num_proc` argument as well, for people who want to make it run faster.
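If `num_proc` is forwarded to `map`, batching happens per worker. This sketch assumes that `map` splits the dataset into `num_proc` contiguous, near-equal shards and batches each shard independently. Under that assumption, batch boundaries restart at every shard, so each shard can end with its own partial batch. The helpers `contiguous_shards` and `batch_sharded` are hypothetical. They simulate the behaviour sequentially on a plain dict of lists, and no processes are spawned.

```python
from typing import Dict, List

Columns = Dict[str, List]  # column name -> list of values


def contiguous_shards(num_rows: int, num_shards: int) -> List[range]:
    """Split row indices into `num_shards` contiguous, near-equal ranges;
    the first `num_rows % num_shards` shards get one extra row."""
    div, mod = divmod(num_rows, num_shards)
    shards = []
    for i in range(num_shards):
        start = div * i + min(i, mod)
        end = start + div + (1 if i < mod else 0)
        shards.append(range(start, end))
    return shards


def batch_sharded(table: Columns, batch_size: int, num_proc: int) -> Columns:
    """Batch each shard independently (sequentially here, standing in for one
    worker per shard), then concatenate the shard outputs in order."""
    num_rows = len(next(iter(table.values())))
    out: Columns = {k: [] for k in table}
    for shard in contiguous_shards(num_rows, num_proc):
        # Batches restart at each shard boundary.
        for start in range(shard.start, shard.stop, batch_size):
            stop = min(start + batch_size, shard.stop)
            for k, v in table.items():
                out[k].append(v[start:stop])
    return out


table = {"id": list(range(10))}
print(batch_sharded(table, batch_size=4, num_proc=1)["id"])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch_sharded(table, batch_size=4, num_proc=2)["id"])  # [[0, 1, 2, 3], [4], [5, 6, 7, 8], [9]]
```

Under this assumption, `num_proc > 1` can yield more, smaller batches than `num_proc=1`, unless each shard's row count is a multiple of `batch_size`.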

lappemic commented 4 months ago

Thanks for the feedback @lhoestq! The last commits include:

WDYT?

lhoestq commented 4 months ago

You can put the documentation in process.mdx :)

HuggingFaceDocBuilderDev commented 4 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

lappemic commented 4 months ago

I reset HEAD to the commit before I added the `Dataset.batch()` documentation to `stream.mdx`, and added the documentation to `process.mdx` instead.

github-actions[bot] commented 4 months ago
PyArrow==8.0.0

### Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|--------|---|
| read_batch_formatted_as_numpy after write_array2d | 0.005736 / 0.011353 (-0.005617) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003959 / 0.011008 (-0.007049) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.063259 / 0.038508 (0.024751) |
| read_batch_unformated after write_array2d | 0.030705 / 0.023109 (0.007596) |
| read_batch_unformated after write_flattened_sequence | 0.245706 / 0.275898 (-0.030192) |
| read_batch_unformated after write_nested_sequence | 0.278766 / 0.323480 (-0.044714) |
| read_col_formatted_as_numpy after write_array2d | 0.003354 / 0.007986 (-0.004632) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004246 / 0.004328 (-0.000082) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.049346 / 0.004250 (0.045095) |
| read_col_unformated after write_array2d | 0.046439 / 0.037052 (0.009386) |
| read_col_unformated after write_flattened_sequence | 0.257930 / 0.258489 (-0.000559) |
| read_col_unformated after write_nested_sequence | 0.295562 / 0.293841 (0.001722) |
| read_formatted_as_numpy after write_array2d | 0.030529 / 0.128546 (-0.098017) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012465 / 0.075646 (-0.063182) |
| read_formatted_as_numpy after write_nested_sequence | 0.205595 / 0.419271 (-0.213677) |
| read_unformated after write_array2d | 0.036319 / 0.043533 (-0.007214) |
| read_unformated after write_flattened_sequence | 0.243872 / 0.255139 (-0.011267) |
| read_unformated after write_nested_sequence | 0.275834 / 0.283200 (-0.007366) |
| write_array2d | 0.020330 / 0.141683 (-0.121353) |
| write_flattened_sequence | 1.108337 / 1.452155 (-0.343817) |
| write_nested_sequence | 1.150406 / 1.492716 (-0.342310) |

### Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|--------|---|
| get_batch_of_1024_random_rows | 0.113498 / 0.018006 (0.095491) |
| get_batch_of_1024_rows | 0.306654 / 0.000490 (0.306164) |
| get_first_row | 0.000238 / 0.000200 (0.000038) |
| get_last_row | 0.000043 / 0.000054 (-0.000012) |

### Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|--------|---|
| select | 0.019092 / 0.037411 (-0.018319) |
| shard | 0.063180 / 0.014526 (0.048654) |
| shuffle | 0.078244 / 0.176557 (-0.098313) |
| sort | 0.126106 / 0.737135 (-0.611030) |
| train_test_split | 0.078651 / 0.296338 (-0.217687) |

### Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|--------|---|
| read 5000 | 0.284132 / 0.215209 (0.068923) |
| read 50000 | 2.781250 / 2.077655 (0.703595) |
| read_batch 50000 10 | 1.471864 / 1.504120 (-0.032256) |
| read_batch 50000 100 | 1.354661 / 1.541195 (-0.186534) |
| read_batch 50000 1000 | 1.362839 / 1.468490 (-0.105651) |
| read_formatted numpy 5000 | 0.719126 / 4.584777 (-3.865651) |
| read_formatted pandas 5000 | 2.396969 / 3.745712 (-1.348743) |
| read_formatted tensorflow 5000 | 2.987924 / 5.269862 (-2.281938) |
| read_formatted torch 5000 | 1.910555 / 4.565676 (-2.655121) |
| read_formatted_batch numpy 5000 10 | 0.078612 / 0.424275 (-0.345663) |
| read_formatted_batch numpy 5000 1000 | 0.005170 / 0.007607 (-0.002437) |
| shuffled read 5000 | 0.333876 / 0.226044 (0.107832) |
| shuffled read 50000 | 3.298340 / 2.268929 (1.029412) |
| shuffled read_batch 50000 10 | 1.853332 / 55.444624 (-53.591292) |
| shuffled read_batch 50000 100 | 1.551919 / 6.876477 (-5.324557) |
| shuffled read_batch 50000 1000 | 1.585677 / 2.142072 (-0.556395) |
| shuffled read_formatted numpy 5000 | 0.802487 / 4.805227 (-4.002741) |
| shuffled read_formatted_batch numpy 5000 10 | 0.134828 / 6.500664 (-6.365837) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.041966 / 0.075469 (-0.033503) |

### Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|--------|---|
| filter | 0.992277 / 1.841788 (-0.849511) |
| map fast-tokenizer batched | 11.626887 / 8.074308 (3.552578) |
| map identity | 9.715623 / 10.191392 (-0.475769) |
| map identity batched | 0.140306 / 0.680424 (-0.540117) |
| map no-op batched | 0.014528 / 0.534201 (-0.519673) |
| map no-op batched numpy | 0.306247 / 0.579283 (-0.273036) |
| map no-op batched pandas | 0.263067 / 0.434364 (-0.171297) |
| map no-op batched pytorch | 0.342325 / 0.540337 (-0.198013) |
| map no-op batched tensorflow | 0.432299 / 1.386936 (-0.954637) |
PyArrow==latest

### Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|--------|---|
| read_batch_formatted_as_numpy after write_array2d | 0.006004 / 0.011353 (-0.005349) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003890 / 0.011008 (-0.007118) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.050408 / 0.038508 (0.011900) |
| read_batch_unformated after write_array2d | 0.031880 / 0.023109 (0.008771) |
| read_batch_unformated after write_flattened_sequence | 0.273114 / 0.275898 (-0.002784) |
| read_batch_unformated after write_nested_sequence | 0.296653 / 0.323480 (-0.026826) |
| read_col_formatted_as_numpy after write_array2d | 0.004569 / 0.007986 (-0.003416) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.002831 / 0.004328 (-0.001497) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.050032 / 0.004250 (0.045782) |
| read_col_unformated after write_array2d | 0.040468 / 0.037052 (0.003415) |
| read_col_unformated after write_flattened_sequence | 0.284718 / 0.258489 (0.026229) |
| read_col_unformated after write_nested_sequence | 0.321754 / 0.293841 (0.027913) |
| read_formatted_as_numpy after write_array2d | 0.033863 / 0.128546 (-0.094684) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012183 / 0.075646 (-0.063463) |
| read_formatted_as_numpy after write_nested_sequence | 0.060805 / 0.419271 (-0.358466) |
| read_unformated after write_array2d | 0.034919 / 0.043533 (-0.008614) |
| read_unformated after write_flattened_sequence | 0.274354 / 0.255139 (0.019215) |
| read_unformated after write_nested_sequence | 0.293477 / 0.283200 (0.010277) |
| write_array2d | 0.019418 / 0.141683 (-0.122265) |
| write_flattened_sequence | 1.151571 / 1.452155 (-0.300584) |
| write_nested_sequence | 1.217174 / 1.492716 (-0.275542) |

### Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|--------|---|
| get_batch_of_1024_random_rows | 0.097326 / 0.018006 (0.079320) |
| get_batch_of_1024_rows | 0.316277 / 0.000490 (0.315787) |
| get_first_row | 0.000225 / 0.000200 (0.000025) |
| get_last_row | 0.000045 / 0.000054 (-0.000009) |

### Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|--------|---|
| select | 0.022932 / 0.037411 (-0.014479) |
| shard | 0.077455 / 0.014526 (0.062929) |
| shuffle | 0.088949 / 0.176557 (-0.087608) |
| sort | 0.129447 / 0.737135 (-0.607688) |
| train_test_split | 0.093705 / 0.296338 (-0.202634) |

### Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|--------|---|
| read 5000 | 0.303918 / 0.215209 (0.088709) |
| read 50000 | 2.973866 / 2.077655 (0.896211) |
| read_batch 50000 10 | 1.593165 / 1.504120 (0.089045) |
| read_batch 50000 100 | 1.465312 / 1.541195 (-0.075883) |
| read_batch 50000 1000 | 1.484503 / 1.468490 (0.016013) |
| read_formatted numpy 5000 | 0.731849 / 4.584777 (-3.852928) |
| read_formatted pandas 5000 | 0.953337 / 3.745712 (-2.792375) |
| read_formatted tensorflow 5000 | 2.887815 / 5.269862 (-2.382047) |
| read_formatted torch 5000 | 1.923618 / 4.565676 (-2.642058) |
| read_formatted_batch numpy 5000 10 | 0.080073 / 0.424275 (-0.344202) |
| read_formatted_batch numpy 5000 1000 | 0.005460 / 0.007607 (-0.002148) |
| shuffled read 5000 | 0.359876 / 0.226044 (0.133832) |
| shuffled read 50000 | 3.532251 / 2.268929 (1.263323) |
| shuffled read_batch 50000 10 | 1.987778 / 55.444624 (-53.456846) |
| shuffled read_batch 50000 100 | 1.685572 / 6.876477 (-5.190905) |
| shuffled read_batch 50000 1000 | 1.827141 / 2.142072 (-0.314932) |
| shuffled read_formatted numpy 5000 | 0.815953 / 4.805227 (-3.989274) |
| shuffled read_formatted_batch numpy 5000 10 | 0.136698 / 6.500664 (-6.363967) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.042185 / 0.075469 (-0.033285) |

### Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|--------|---|
| filter | 1.032508 / 1.841788 (-0.809280) |
| map fast-tokenizer batched | 12.526918 / 8.074308 (4.452610) |
| map identity | 10.202942 / 10.191392 (0.011550) |
| map identity batched | 0.145920 / 0.680424 (-0.534504) |
| map no-op batched | 0.015643 / 0.534201 (-0.518558) |
| map no-op batched numpy | 0.300465 / 0.579283 (-0.278818) |
| map no-op batched pandas | 0.126786 / 0.434364 (-0.307578) |
| map no-op batched pytorch | 0.342885 / 0.540337 (-0.197453) |
| map no-op batched tensorflow | 0.438139 / 1.386936 (-0.948797) |