versae opened this issue 1 year ago
Hi! I agree `skip` can be inefficient to use in the current state.
To make it fast, we could use "statistics" stored in Parquet metadata and read only the chunks needed to form a dataset.
And thanks to the "datasets-server" project, which aims to store the Parquet versions of the Hub datasets (only the smaller datasets are covered currently), this solution can also be applied to datasets stored in formats other than Parquet. (cc @severo)
@mariosasko do the current parquet files created by the datasets-server already have the required "statistics"? If not, please open an issue on https://github.com/huggingface/datasets-server with some details to make sure we implement it.
Yes, nothing has to be changed on the datasets-server side. What I mean by "statistics" is that we can use the "row_group" metadata embedded in a Parquet file (by default) to fetch the requested rows more efficiently.
Glad to see the feature could be of interest.
I'm sure there are many possible ways to implement this feature. I don't know enough about the datasets-server, but I guess that it is not instantaneous, in the sense that user-owned private datasets might need hours or days until they are ported to the datasets-server (if at all), which could be cumbersome. Having that information optionally available in the `dataset_infos.json` file would give users a bit more control over the skip process.
re: statistics:
```python
>>> import pyarrow.parquet as pq
>>> import hffs
>>> fs = hffs.HfFileSystem("glue", repo_type="dataset", revision="refs/convert/parquet")
>>> metadata = pq.read_metadata("ax/glue-test.parquet", filesystem=fs)
>>> metadata
<pyarrow._parquet.FileMetaData object at 0x7f4537cec400>
  created_by: parquet-cpp-arrow version 7.0.0
  num_columns: 4
  num_rows: 1104
  num_row_groups: 2
  format_version: 1.0
  serialized_size: 2902
>>> metadata.row_group(0)
<pyarrow._parquet.RowGroupMetaData object at 0x7f45564bcbd0>
  num_columns: 4
  num_rows: 1000
  total_byte_size: 164474
>>> metadata.row_group(1)
<pyarrow._parquet.RowGroupMetaData object at 0x7f455005c400>
  num_columns: 4
  num_rows: 104
  total_byte_size: 13064
```
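A minimal sketch of how these per-row-group counts could back `skip(n)` without downloading full shards (the function name is made up for illustration; a real implementation would live in the streaming code of `datasets`):

```python
import pyarrow.parquet as pq

def iter_examples_from(parquet_file, n_skip):
    """Yield examples starting at index `n_skip`, reading only the row groups
    that are actually needed (earlier row groups are skipped via metadata)."""
    pf = pq.ParquetFile(parquet_file)  # accepts a path or a file-like object, e.g. fs.open(...)
    remaining = n_skip
    for i in range(pf.metadata.num_row_groups):
        group_rows = pf.metadata.row_group(i).num_rows
        if remaining >= group_rows:
            # The whole row group lies before the requested offset:
            # account for it without reading its data pages.
            remaining -= group_rows
            continue
        table = pf.read_row_group(i)
        # Drop the leading rows of the first row group we actually read.
        yield from table.slice(remaining).to_pylist()
        remaining = 0
```

With a seekable remote file (e.g. one opened through the filesystem above), only the footer and the byte ranges of the remaining row groups would need to be fetched.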
> user-owned private datasets might need hours or days until they are ported to the datasets-server (if at all)

Private datasets are not supported yet (https://github.com/huggingface/datasets-server/issues/39).
@versae `Dataset.push_to_hub` writes shards in Parquet, so this solution would also work for such datasets (immediately after the push).
@mariosasko that is right. However, there is still a good number of datasets for which the shards are created manually. In our very specific case, we create medium-sized datasets (rarely over 100-200GB) of both text and audio, we prepare the shards by hand and then upload them. It would be great to have immediate access to this download-skipping feature for them too.
From looking at Arrow's source, it seems Parquet stores metadata at the end of the file, which means one needs to iterate over a Parquet file's data before accessing its metadata. We could mimic Dask to address this "limitation" and write the metadata to a `_metadata`/`_common_metadata` file in `to_parquet`/`push_to_hub`, which we could then use to optimize reads (if present). Plus, it's handy that PyArrow can also parse these metadata files.
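For reference, a rough sketch of what writing such metadata files could look like with PyArrow (the output directory and the toy table are made up; this is not what `to_parquet`/`push_to_hub` do today):

```python
import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"idx": list(range(1000))})

# Write the shards and collect the FileMetaData of every file as it is written.
metadata_collector = []
pq.write_to_dataset(table, "dataset_dir", metadata_collector=metadata_collector)

# _common_metadata: schema only, no row-group statistics.
pq.write_metadata(table.schema, "dataset_dir/_common_metadata")

# _metadata: schema plus the row-group metadata of all written files, which a
# reader can consult to locate rows without opening the shards themselves.
pq.write_metadata(table.schema, "dataset_dir/_metadata", metadata_collector=metadata_collector)
```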
So if Parquet metadata needs to be in its own file anyway, why not implement this skipping feature by storing the example counts per shard in `dataset_infos.json`? That would allow it to work for manually uploaded shards as well as those written by `.push_to_hub()`. A proper Parquet metadata file could still be created and "overwrite" the `dataset_infos.json` info in the datasets-server.
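Just to illustrate the idea (field names and layout are hypothetical, not the current `dataset_infos.json` schema), the per-shard counts could look roughly like this:

```python
# Hypothetical per-shard entries next to the existing split-level num_bytes/num_examples.
split_info = {
    "train": {
        "num_bytes": 123_456_789,
        "num_examples": 250_000,
        "shards": [
            {"filename": "train-00000-of-00002.parquet", "num_examples": 125_000},
            {"filename": "train-00001-of-00002.parquet", "num_examples": 125_000},
        ],
    }
}
```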
Feature request
Add extra information to the `dataset_infos.json` file to include the number of samples/examples in each shard, for example in a new field `num_examples` alongside `num_bytes`. The `.skip()` function could use this information to skip the download of a shard entirely when in streaming mode, which AFAICT should speed up the skipping process.
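A rough sketch (names are illustrative, not the actual `datasets` internals) of how `.skip(n)` could use per-shard example counts to avoid downloading the fully skipped shards:

```python
def first_shard_to_read(shard_num_examples, n_skip):
    """Return (shard_index, offset_within_shard) for the first shard that still
    has to be downloaded after skipping `n_skip` examples."""
    for shard_idx, num_examples in enumerate(shard_num_examples):
        if n_skip < num_examples:
            return shard_idx, n_skip
        n_skip -= num_examples
    return len(shard_num_examples), 0  # everything was skipped

# With three shards of 125_000 examples each, skipping 300_000 examples resumes
# at shard 2, example 50_000 -- shards 0 and 1 never need to be downloaded.
print(first_shard_to_read([125_000, 125_000, 125_000], 300_000))  # (2, 50000)
```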
Motivation

When resuming from a checkpoint after a crashed run, using `dataset.skip()` is very convenient to recover the exact state of the data and avoid training again on the same examples (assuming the same seed and no shuffling). However, I have noticed that for audio datasets in streaming mode this is very costly in terms of time, as shards need to be downloaded every time before skipping the right number of examples.

Your contribution
I already took a look at the code, but it seems a change like this goes deeper than I am able to manage, as it touches several parts of the library. I could give it a try, but I might need some guidance on the internals.