Open NivekT opened 2 years ago
I am pretty fine with adding tqdm as a dependency since it's commonly used across ML domains.
I think the only problem is how we can detect when to stop the current progress bar and start a new one. Progress bars are useful in scenarios where the downloaded data is large, so in most cases it will also be chunked before being written to disk:
import tqdm
from torchdata.datapipes.iter import StreamReader, IterableWrapper, HttpReader, IterDataPipe, Saver
from typing import Callable, Iterator, TypeVar
D = TypeVar("D")
class ProgressBar(IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[D], *, update_fn: Callable[[D], int]) -> None:
        self.datapipe = datapipe
        self.update_fn = update_fn

    def __iter__(self) -> Iterator[D]:
        with tqdm.tqdm() as progress_bar:
            for data in self.datapipe:
                progress_bar.update(self.update_fn(data))
                yield data

dp = IterableWrapper(["http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"])
dp = HttpReader(dp)
dp = StreamReader(dp, chunk=1024 * 1024)
dp = ProgressBar(dp, update_fn=lambda data: len(data[1]))
dp = Saver(dp, mode="wb", filepath_fn=lambda url: f"./{url.split('/')[-1]}")
list(dp)
This works fine and is also what I implemented in pytorch/vision@2d03f026d2c04c38b23540f95c71f2006e4fbbcb.
The problem starts when torchdata wants to support ProgressBar in a general way. Suppose I add another URL to the input:
dp = IterableWrapper(
    [
        "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
        "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
    ]
)
Now we still only get a single progress bar but this time for both files. That is probably not what we want. We could use another function to be able to detect this:
from typing import Optional

# Type of the key returned by reset_fn (was used but not defined in the original snippet).
K = TypeVar("K")


class ProgressBar(IterDataPipe):
    def __init__(
        self,
        datapipe: IterDataPipe[D],
        *,
        update_fn: Callable[[D], int],
        reset_fn: Optional[Callable[[D], K]] = None,
    ) -> None:
        self.datapipe = datapipe
        self.update_fn = update_fn
        self.reset_fn = reset_fn

    def __iter__(self) -> Iterator[D]:
        progress_bar = tqdm.tqdm()
        # Sentinel marks "no key seen yet" so the first item never triggers a reset.
        sentinel = object()
        current_key = sentinel
        for data in self.datapipe:
            if self.reset_fn:
                key = self.reset_fn(data)
                # A changed key means a new file started: close the old bar, open a new one.
                if key != current_key and current_key is not sentinel:
                    progress_bar.close()
                    progress_bar = tqdm.tqdm()
                current_key = key

            progress_bar.update(self.update_fn(data))
            yield data
        progress_bar.close()
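To make the `reset_fn` behaviour easier to check without a network connection, here is a minimal sketch of the same boundary-detection logic on plain Python iterables. `RecordingBar`, `track`, and `bar_factory` are hypothetical names introduced only for this illustration; `RecordingBar` is a stand-in for `tqdm.tqdm` that just counts, and the chunks are made-up `(filename, bytes)` pairs rather than real `HttpReader` output.

```python
from typing import Callable, Hashable, Iterable, Iterator, List, Optional, Tuple, TypeVar

D = TypeVar("D")


class RecordingBar:
    """Hypothetical stand-in for tqdm.tqdm: it only counts what it is given."""

    def __init__(self) -> None:
        self.n = 0
        self.closed = False

    def update(self, n: int) -> None:
        self.n += n

    def close(self) -> None:
        self.closed = True


def track(
    items: Iterable[D],
    *,
    update_fn: Callable[[D], int],
    reset_fn: Callable[[D], Hashable],
    bar_factory: Callable[[], RecordingBar] = RecordingBar,
) -> Iterator[D]:
    """Yield items unchanged, opening a fresh bar whenever reset_fn's key changes."""
    bar: Optional[RecordingBar] = None
    current_key: Optional[Hashable] = None
    try:
        for item in items:
            key = reset_fn(item)
            if bar is None or key != current_key:
                if bar is not None:
                    bar.close()
                bar = bar_factory()
                current_key = key
            bar.update(update_fn(item))
            yield item
    finally:
        # Also runs when the consumer stops early, so the last bar is not left open.
        if bar is not None:
            bar.close()


# Usage with made-up chunks from two files.
bars: List[RecordingBar] = []


def make_bar() -> RecordingBar:
    bar = RecordingBar()
    bars.append(bar)
    return bar


chunks: List[Tuple[str, bytes]] = [
    ("train-images-idx3-ubyte.gz", b"a" * 4),
    ("train-images-idx3-ubyte.gz", b"b" * 2),
    ("train-labels-idx1-ubyte.gz", b"c" * 3),
]
seen = list(
    track(
        chunks,
        update_fn=lambda chunk: len(chunk[1]),
        reset_fn=lambda chunk: chunk[0],
        bar_factory=make_bar,
    )
)
```

Passing `tqdm.tqdm` as `bar_factory` should give one real progress bar per file. Note that this only works if all chunks of one file arrive consecutively; interleaved streams would need a bar per key instead.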
Throwing my 2 cents here. Below is a progress bar I'm using for my own work.
I'm wondering if we can add the ability to display the total, for example when the data is a tuple of ints/floats.
I really like @pmeier's idea of a reset_fn and update_fn. Maybe we could also have total_fn: Optional[Callable[[D], int]] for extracting the total expected iterations.
Below is what I have, but @pmeier's version seems more flexible.
# Imports assumed by this snippet (not shown in the original comment).
from typing import Iterator, Optional, TypeVar

import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from tqdm import tqdm

T_co = TypeVar("T_co", covariant=True)


@functional_datapipe("progress_bar")
class ProgressBarTracker(dp.iter.IterDataPipe[T_co]):
    def __init__(
        self,
        # The source IterDataPipe to wrap with progress tracking.
        source_datapipe: dp.iter.IterDataPipe[T_co],
        # A description to display alongside the progress bar. Defaults to None.
        desc: Optional[str] = None,
    ) -> None:
        """
        A DataPipe that provides a progress bar for iteration over a dataset.
        If the returned 'data' is a `Tuple[float,float]`, then we will update the
        tqdm to factor this in.
        """
        self.source_datapipe = source_datapipe
        self.desc = desc

    def __iter__(self) -> Iterator[T_co]:
        pbar = tqdm(total=None, desc=self.desc, dynamic_ncols=True)
        try:
            has_total = False
            for data in self.source_datapipe:
                if (
                    isinstance(data, tuple)
                    and len(data) == 2
                    and all(isinstance(o, (int, float)) for o in data)
                ):
                    # Interpret the item as (current, total) and set the bar position directly.
                    current, total = data
                    has_total = True
                    if pbar.total is None:
                        pbar.total = total
                    pbar.total = max(pbar.total, total)
                    pbar.n = current
                    pbar.refresh()
                elif not has_total:
                    # tqdm.update takes an increment, not an absolute index,
                    # so advance by one item (the original passed the enumerate index).
                    pbar.update(1)
                yield data
        finally:
            pbar.close()
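The `total_fn` suggestion above could be combined with `update_fn` roughly as follows. This is only a sketch under assumptions: `CountingBar` and `track_with_total` are hypothetical names, `CountingBar` stands in for `tqdm.tqdm` and mimics only its `n` and `total` attributes, and each made-up record carries its chunk together with the expected overall size.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar

D = TypeVar("D")


class CountingBar:
    """Hypothetical stand-in for tqdm.tqdm, exposing only `n` and `total`."""

    def __init__(self, total: Optional[int] = None) -> None:
        self.n = 0
        self.total = total

    def update(self, n: int = 1) -> None:
        self.n += n

    def close(self) -> None:
        pass


def track_with_total(
    items: Iterable[D],
    *,
    bar: CountingBar,
    update_fn: Callable[[D], int],
    total_fn: Optional[Callable[[D], Optional[int]]] = None,
) -> Iterator[D]:
    """Yield items unchanged while advancing `bar` and learning its total from the data."""
    try:
        for item in items:
            if total_fn is not None:
                total = total_fn(item)
                if total is not None:
                    # Keep the largest total reported so far, as in the snippet above.
                    bar.total = total if bar.total is None else max(bar.total, total)
            bar.update(update_fn(item))
            yield item
    finally:
        bar.close()


# Usage with made-up (chunk, expected_size_in_bytes) records.
bar = CountingBar()
records: List[Tuple[bytes, int]] = [(b"abcd", 10), (b"efgh", 10), (b"ij", 10)]
out = list(
    track_with_total(
        records,
        bar=bar,
        update_fn=lambda record: len(record[0]),
        total_fn=lambda record: record[1],
    )
)
```

Keeping `total_fn` optional means pipelines that cannot know the size up front still get a plain counter, while ones that can (for example from a `Content-Length` header) get a percentage and an ETA.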
🚀 The feature
Add a progress bar to remote DataPipes that will be shown in the terminal to display the status of the operation. We can potentially use tqdm or rich.
Previous discussion comes from this comment by @pmeier.
Motivation, pitch
For DataPipes that download data from a remote server, it is useful to have a progress bar to show the download progress so that users can know whether the process is working and what the estimated remaining time is. Relevant DataPipes may be HttpReader, GDriveReader, iopath, fsspec and S3 DataPipes.
There may be usages for other types of DataPipes as well.
Alternatives
There are many ways to display the status of the operations to users. Feel free to suggest other ideas.