uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark, and can be used from pure Python code.
Apache License 2.0

How to get the length of a dataset when using the PyTorch DataLoader API? #710

Closed · aseembits93 closed 3 years ago

aseembits93 commented 3 years ago

Hi, thanks for sharing this repo. It's really useful for my work. I was wondering how I can find the size of a dataset I'm loading through the PyTorch API.

from petastorm import make_reader
from petastorm.pytorch import DataLoader

with DataLoader(make_reader('file:///tmp/helloworld'), batch_size=64) as train_loader:
    # How can I get the number of rows in the dataset here?
    print("length of dataset")

Let me know how I can do it. Thanks in advance!

selitvin commented 3 years ago

Currently Petastorm has no API for reporting the number of rows that will be returned. It should be possible to add one for the make_reader API, but it is also possible to query the count using standard pyarrow Parquet reading tools.
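
For example, a minimal sketch using pyarrow's dataset API (this is an illustration, not a Petastorm API; it assumes a reasonably recent pyarrow and that the question's file:///tmp/helloworld URL maps to a local directory of Parquet files):

import pyarrow.dataset as ds

# Point pyarrow directly at the directory backing the Petastorm dataset.
# Dataset discovery skips files prefixed with "_" (such as _metadata) by default.
dataset = ds.dataset("/tmp/helloworld", format="parquet")

# Total number of rows across all Parquet files in the dataset.
print(dataset.count_rows())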

v01dXYZ commented 3 years ago

Assuming your dataset was created with materialize_dataset, you can query the row count directly from the _metadata Parquet metadata file:

import pyarrow.parquet as pq

# _metadata aggregates the row-group metadata of the whole dataset,
# so the row count is available without reading any data files.
pmd = pq.read_metadata("/tmp/helloworld/_metadata")

pmd.num_rows
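
To tie this back to the PyTorch loader in the question, a minimal sketch (reusing the /tmp/helloworld path and batch size 64 from above; with make_reader's default single pass over the data, the number of batches is roughly the row count divided by the batch size, rounded up):

import math

import pyarrow.parquet as pq
from petastorm import make_reader
from petastorm.pytorch import DataLoader

num_rows = pq.read_metadata("/tmp/helloworld/_metadata").num_rows
batch_size = 64

with DataLoader(make_reader('file:///tmp/helloworld'), batch_size=batch_size) as train_loader:
    # Derive an approximate batch count from the metadata row count,
    # since the loader itself does not expose a length.
    print(num_rows, math.ceil(num_rows / batch_size))
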
aseembits93 commented 3 years ago

Thanks @v01dXYZ! I'll close the issue.