pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Adding Progress Bar to Remote DataPipes #457

Open NivekT opened 2 years ago

NivekT commented 2 years ago

🚀 The feature

Add a progress bar to remote DataPipes that will be shown in the terminal to display the status of the operation. We can potentially use tqdm or rich.

The previous discussion comes from this comment by @pmeier.

Motivation, pitch

For DataPipes that download data from a remote server, it is useful to have a progress bar to show the download progress so that the users can know if the process is working and what the estimated remaining time is. Relevant DataPipes may be HttpReader, GDriveReader, iopath, fsspec and S3 DataPipes.

There may be usages for other types of DataPipes as well.

Alternatives

There are many ways to display the status of the operations to users. Feel free to suggest other ideas.

Additional context

No response

ejguan commented 2 years ago

I am fine with adding tqdm as a dependency, since it's commonly used across ML domains.

pmeier commented 2 years ago

I think the only problem is how we can detect when to stop the current progress bar and start a new one. Progress bars are useful in scenarios where the downloaded data is large, so in most cases it will also be chunked before being written to disk:

import tqdm
from torchdata.datapipes.iter import StreamReader, IterableWrapper, HttpReader, IterDataPipe, Saver

from typing import Callable, Iterator, TypeVar

D = TypeVar("D")

class ProgressBar(IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[D], *, update_fn: Callable[[D], int]) -> None:
        self.datapipe = datapipe
        self.update_fn = update_fn

    def __iter__(self) -> Iterator[D]:
        with tqdm.tqdm() as progress_bar:
            for data in self.datapipe:
                progress_bar.update(self.update_fn(data))
                yield data

dp = IterableWrapper(["http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"])
dp = HttpReader(dp)
dp = StreamReader(dp, chunk=1024 * 1024)
dp = ProgressBar(dp, update_fn=lambda data: len(data[1]))
dp = Saver(dp, mode="wb", filepath_fn=lambda url: f"./{url.split('/')[-1]}")

list(dp)

This works fine and is also what I implemented in pytorch/vision@2d03f026d2c04c38b23540f95c71f2006e4fbbcb.

The problem starts when torchdata wants to support ProgressBar in a general way. Suppose I add another URL to the input:

dp = IterableWrapper(
    [
        "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
        "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
    ]
)

Now we still only get a single progress bar, but this time it covers both files. That is probably not what we want. We could use another function to detect when a new file starts:

from typing import Optional

K = TypeVar("K")  # type of the key returned by reset_fn

class ProgressBar(IterDataPipe):
    def __init__(
        self,
        datapipe: IterDataPipe[D],
        *,
        update_fn: Callable[[D], int],
        reset_fn: Optional[Callable[[D], K]] = None,
    ) -> None:
        self.datapipe = datapipe
        self.update_fn = update_fn
        self.reset_fn = reset_fn

    def __iter__(self) -> Iterator[D]:
        progress_bar = tqdm.tqdm()
        sentinel = object()
        current_key = sentinel
        for data in self.datapipe:
            if self.reset_fn:
                key = self.reset_fn(data)
                if key != current_key and current_key is not sentinel:
                    progress_bar.close()
                    progress_bar = tqdm.tqdm()
                current_key = key

            progress_bar.update(self.update_fn(data))
            yield data
        progress_bar.close()
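
As a rough, dependency-free illustration of the key-change logic above: resetting on a changed key is essentially what itertools.groupby does for consecutive items that share a key. The sketch below assumes items are (url, chunk) tuples like the ones produced in the first example. RecordingBar is a hypothetical stand-in for tqdm.tqdm so the snippet runs without tqdm installed, and per_key_progress and make_bar are illustrative names, not torchdata APIs.

```python
from itertools import groupby
from typing import Callable, Hashable, Iterable, Iterator, List, Tuple, TypeVar

D = TypeVar("D")


class RecordingBar:
    """Hypothetical stand-in for tqdm.tqdm: records progress instead of drawing it."""

    def __init__(self, desc: str = "") -> None:
        self.desc = desc
        self.n = 0
        self.closed = False

    def update(self, n: int = 1) -> None:
        self.n += n

    def close(self) -> None:
        self.closed = True


def per_key_progress(
    items: Iterable[D],
    *,
    update_fn: Callable[[D], int],
    reset_fn: Callable[[D], Hashable],
    bar_factory: Callable[[str], RecordingBar] = RecordingBar,
) -> Iterator[D]:
    # groupby starts a new group whenever the key of consecutive items changes,
    # which is the same condition as `key != current_key` in the class above.
    for key, group in groupby(items, key=reset_fn):
        bar = bar_factory(str(key))
        try:
            for item in group:
                bar.update(update_fn(item))
                yield item
        finally:
            # Close the bar even if the consumer stops iterating early.
            bar.close()


# Usage: two "files", the first delivered in two chunks.
chunks: List[Tuple[str, bytes]] = [("a.gz", b"xx"), ("a.gz", b"y"), ("b.gz", b"zzzz")]
bars: List[RecordingBar] = []


def make_bar(desc: str) -> RecordingBar:
    bar = RecordingBar(desc)
    bars.append(bar)
    return bar


out = list(
    per_key_progress(
        chunks,
        update_fn=lambda d: len(d[1]),
        reset_fn=lambda d: d[0],
        bar_factory=make_bar,
    )
)
summary = [(b.desc, b.n, b.closed) for b in bars]
```

With tqdm installed, passing something like bar_factory=lambda desc: tqdm.tqdm(desc=desc) should give one labelled bar per file. Like the class above, this only detects changes between consecutive items, so interleaved chunks from different files would open a new bar on every switch.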

josiahls commented 1 year ago

Throwing in my 2 cents here. Below is a progress bar I'm using for my own work.

I'm wondering if we can add the ability to display the total, for example when the data is a tuple of ints/floats.

I really like @pmeier's idea of a reset_fn and update_fn. Maybe we could also have total_fn: Optional[Callable[[D], int]] for extracting the total expected iterations.

Below is what I have, but @pmeier's approach seems more flexible.

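from typing import Iterator, Optional, TypeVar

# Assumed imports; the original snippet did not show them.
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from tqdm import tqdm

T_co = TypeVar("T_co", covariant=True)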

@functional_datapipe("progress_bar")
class ProgressBarTracker(dp.iter.IterDataPipe[T_co]):
    def __init__(
            self,
            # The source IterDataPipe to wrap with progress tracking.
            source_datapipe:dp.iter.IterDataPipe[T_co], 
            # A description to display alongside the progress bar. Defaults to None.
            desc:Optional[str] = None
        ) -> None:
        """
        A DataPipe that provides a progress bar for iteration over a dataset.

        If the returned 'data' is a `Tuple[float,float]`, then we will update the 
        tqdm to factor this in.
        """
        self.source_datapipe = source_datapipe
        self.desc = desc

    def __iter__(self) -> Iterator[T_co]:
        pbar = tqdm(total=None, desc=self.desc, dynamic_ncols=True)
        try:
            has_total = False
            for i,data in enumerate(self.source_datapipe):
                if isinstance(data, tuple) \
                    and len(data)==2 \
                    and all(isinstance(o,(int,float)) for o in data):
                        current, total = data
                        has_total = True
                        if pbar.total is None:
                            pbar.total = total
                        pbar.total = max(pbar.total, total)
                        pbar.n = current
                        pbar.refresh()
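                elif not has_total:
                    # tqdm's update() takes an increment, not an absolute position,
                    # so advance by one per item instead of by the running index.
                    pbar.update(1)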
                yield data
        finally:
            pbar.close()
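
One possible reading of the total_fn idea mentioned above: call it on the first item and use the result as the bar's total, then let update_fn advance the bar as before. This is only a sketch under assumptions. SimpleBar is a hypothetical stand-in for tqdm.tqdm, track is an illustrative name, and the items are assumed to be (total_bytes, chunk) tuples; none of this is an existing torchdata API.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar

D = TypeVar("D")


class SimpleBar:
    """Hypothetical stand-in for tqdm.tqdm that only tracks `n` and `total`."""

    def __init__(self, total: Optional[int] = None) -> None:
        self.total = total
        self.n = 0

    def update(self, n: int = 1) -> None:
        self.n += n

    def fraction(self) -> Optional[float]:
        # Without a known total there is no meaningful completion ratio.
        return None if not self.total else self.n / self.total


def track(
    items: Iterable[D],
    bar: SimpleBar,
    *,
    update_fn: Callable[[D], int],
    total_fn: Optional[Callable[[D], int]] = None,
) -> Iterator[D]:
    for item in items:
        # Take the total from the first item that is seen, if a total_fn is given.
        if total_fn is not None and bar.total is None:
            bar.total = total_fn(item)
        bar.update(update_fn(item))
        yield item


# Usage: a 10-byte payload delivered in three chunks.
records: List[Tuple[int, bytes]] = [(10, b"abcd"), (10, b"efgh"), (10, b"ij")]
bar = SimpleBar()
seen: List[Optional[float]] = []
for _ in track(records, bar, update_fn=lambda r: len(r[1]), total_fn=lambda r: r[0]):
    seen.append(bar.fraction())
# seen is now [0.4, 0.8, 1.0]
```

Whether the total should come from the data itself or from metadata such as a file size is left open here; combining this with a reset_fn would mean re-reading the total whenever the key changes.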