uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark and can be used from pure Python code.
Apache License 2.0

Fix DataLoader iter(dataloader) cannot be called more than once #544

Closed WeichenXu123 closed 4 years ago

WeichenXu123 commented 4 years ago

Issue

PyTorch DataLoader allows iter(dataloader) to be called more than once. Specifically, the user can define a function:

def train_one_epoch(dataloader, ...):
    for batch in dataloader:
        ...

and use the function with the same dataloader:

with DataLoader(...) as dataloader:
    for epoch in range(num_epochs):
        train_one_epoch(dataloader, ...)

In the current implementation, when the for batch in dataloader loop runs in the second epoch, no batches are returned.

This PR will enable the usage above.
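To make the symptom concrete, here is a minimal toy sketch of how a loader can end up empty on the second pass. It is not petastorm code: `OneShotLoader` is a hypothetical stand-in, and a plain Python iterator plays the role of the underlying reader.

```python
class OneShotLoader:
    """Hypothetical stand-in for a loader whose __iter__ never resets its reader."""

    def __init__(self, rows):
        # A single shared iterator plays the role of the underlying reader.
        self._reader = iter(rows)

    def __iter__(self):
        # No reset here: every call drains the same, possibly exhausted, reader.
        for row in self._reader:
            yield row


loader = OneShotLoader([1, 2, 3])
first_epoch = list(loader)   # [1, 2, 3]
second_epoch = list(loader)  # [] because the shared reader is already exhausted
```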

Fix

When each new iteration starts, reset the iterator's internal state.

Also add a restriction: a new iteration can start only after the previous one has finished, because all iterators share the same underlying reader and therefore cannot run in parallel.
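The two parts of the fix described above (reset on each new iteration, and refuse to start while another iteration is in progress) could look roughly like the sketch below. This is an illustration under assumptions, not the code in this PR: `ResettableLoader`, `_in_iter`, and `_reset_reader` are hypothetical names, and an in-memory list stands in for the shared reader.

```python
class ResettableLoader:
    """Hypothetical sketch of a re-iterable loader; not petastorm's API."""

    def __init__(self, rows):
        self._rows = list(rows)
        self._in_iter = False  # True while an iteration is in progress

    def _reset_reader(self):
        # Stand-in for resetting the shared underlying reader to its start.
        return iter(self._rows)

    def __iter__(self):
        # Iterators share one reader, so only one may be active at a time.
        if self._in_iter:
            raise RuntimeError("previous iteration has not finished yet")
        self._in_iter = True
        reader = self._reset_reader()
        try:
            for row in reader:
                yield row
        finally:
            # Runs on exhaustion or when the generator is closed.
            self._in_iter = False
```

With this shape, calling the loader once per epoch yields the full data each time, while starting a second iteration before the first finishes raises an error. One consequence of the sketch is that a caller who abandons an iteration early has to close that iterator (or let it be garbage collected) before starting the next one.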

codecov[bot] commented 4 years ago

Codecov Report

Merging #544 into master will decrease coverage by 0.01%. The diff coverage is 88.37%.

@@            Coverage Diff             @@
##           master     #544      +/-   ##
==========================================
- Coverage   86.16%   86.14%   -0.02%     
==========================================
  Files          87       87              
  Lines        4935     4965      +30     
  Branches      786      790       +4     
==========================================
+ Hits         4252     4277      +25     
- Misses        556      560       +4     
- Partials      127      128       +1     
Impacted Files          Coverage Δ
petastorm/pytorch.py    92.05% <88.37%> (-2.17%)

Last update 63c1faf...cd16ca8.