NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

External WebDataset resizing #3658

Closed austinmw closed 2 years ago

austinmw commented 2 years ago

Hi, apologies for creating multiple issues for questions recently! However, I noticed something in the external WebDataset example that could be helpful to others as well.

It seems that the example only works as-is because the input data is MNIST, which is always 28x28. If I try to swap this to use a dataset that has variable image sizes, it breaks.

I think that's because it's trying to do batched resizing, when instead it should do single-image resizing followed by batched augmentations for the rest.

For a solution, which of these would be best (or is there another option)?

  1. Pre-resize images before creating the WebDataset shards
  2. Have WebDataset decode and resize instead of DALI
  3. Have DALI somehow do a single-image resize followed by the other augmentations batched (not sure how to do this?)

Thanks!

JanuszL commented 2 years ago

Hi @austinmw,

apologies for creating multiple issues for questions recently!

No problem. We are happy to help.

If I try to swap this to use a dataset that has variable image sizes it breaks.

That is not expected; maybe there is an error in the example. Can you provide a self-contained repro we can run on our side?

I think that's because it's trying to do batched resizing when instead it should do single image resizing followed by batched augmentations for the rest.

That should not be the case; the external source operator can accept a batch of samples with irregular shapes. All DALI operators can process batches of non-uniform data, which is the main advantage over other frameworks, so no additional preprocessing is needed. The getting started example shows DALI processing a batch of images, each with a different resolution.
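
For illustration, here is a minimal, self-contained sketch (not taken from the example, and assuming a GPU is available for device_id=0) of an external_source feeding a batch of differently sized images through fn.resize; the ragged_batches generator and the shapes below are made up:

import numpy as np
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn

# Yields one batch per call: a list of HWC uint8 arrays with different shapes.
def ragged_batches():
    rng = np.random.default_rng(0)
    while True:
        yield [rng.integers(0, 256,
                            size=(int(rng.integers(100, 300)),
                                  int(rng.integers(100, 300)), 3),
                            dtype=np.uint8)
               for _ in range(4)]

@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def ragged_pipe():
    images = fn.external_source(source=ragged_batches, batch=True)
    # fn.resize works per sample, so irregular shapes within a batch are fine
    return fn.resize(images.gpu(), size=(224, 224))

pipe = ragged_pipe()
pipe.build()
images_out, = pipe.run()  # a dense batch of 4 images, each 224x224x3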

austinmw commented 2 years ago

Ah thanks, that's great to hear. I guess I must be doing something incorrectly. Below is an example I have. If I set batch_size=1 it seems to run, but with batch_size > 1 it errors.

# Download and unzip dogs classification dataset
!kaggle competitions download -c dog-breed-identification -p dog-breed-identification
!cd dog-breed-identification && unzip -q dog-breed-identification.zip

import os
import random
import numpy as np
import pandas as pd
import webdataset as wds
from tqdm.autonotebook import tqdm
import nvidia.dali as dali
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator, DALIClassificationIterator, LastBatchPolicy

# Create wds shards
def write_dogs_wd(root_dir='./dog-breed-identification', shard_dir='./shards', maxsize=1e9, maxcount=100000):

    assert maxsize > 10000000
    assert maxcount < 1000000

    if not os.path.exists(shard_dir):
        os.makedirs(shard_dir)

    root_dir = root_dir
    image_dir = os.path.join(root_dir, 'train/')
    img_list = pd.read_csv(os.path.join(root_dir, 'labels.csv'))
    idx2breed=list(img_list['breed'].unique())
    breed2idx = {b: i for i, b in enumerate(idx2breed)}
    nimages = img_list.shape[0]
    print(f'nimages: {nimages}')
    indexes = list(range(nimages))
    random.shuffle(indexes)        

    # This is the output pattern under which we write shards
    pattern = os.path.join(shard_dir, 'train-%06d.tar')
    with wds.ShardWriter(pattern, maxsize=int(maxsize), maxcount=int(maxcount), compress=False) as sink:

        for i in tqdm(indexes):
            key = f'{i:07d}'                        
            img_row = img_list.iloc[i]
            label = breed2idx[img_row['breed']]
            with open(os.path.join(image_dir, f"{img_row['id']}.jpg"), 'rb') as stream:
                image = stream.read()

            # Construct a sample
            sample = {'__key__': key, 'jpg': image, 'cls': label}

            # Write the sample to the sharded tar archives
            sink.write(sample)

write_dogs_wd(maxcount=512)

# DALI
batch_size = 64
num_workers = 8

# The function below is used to later randomize the output from the dataset. 
# The samples are first stored in a prefetch buffer, 
# and then they're randomly yielded in a generator and replaced by a new sample.
def buffered_shuffle(generator_factory, initial_fill, seed):
    def buffered_shuffle_generator():
        nonlocal generator_factory, initial_fill, seed
        generator = generator_factory()
        # The buffer size must be positive
        assert(initial_fill > 0)

        # The buffer that will hold the randomized samples
        buffer = []

        # The random context for preventing side effects
        random_context = random.Random(seed)

        try:
            while len(buffer) < initial_fill: # Fills in the random buffer
                buffer.append(next(generator))

            while True: # Selects a random sample from the buffer and then fills it back in with a new one
                idx = random_context.randint(0, initial_fill-1)

                yield buffer[idx]
                buffer[idx] = None
                buffer[idx] = next(generator)

        except StopIteration: # When the generator runs out of samples, flush out the buffer
            random_context.shuffle(buffer)

            while buffer:
                if buffer[-1] is not None: # Prevents the one sample that was not filled from being duplicated
                    yield buffer[-1]
                buffer.pop()
    return buffered_shuffle_generator

# The next function is used for padding the last batch with the last sample, 
# in order to make it the same size as all the other ones.
def last_batch_padding(generator_factory, batch_size):
    def last_batch_padding_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        in_batch_idx = 0
        last_item = None
        try:
            while True: # Keeps track of the last sample and the sample number mod batch_size
                if in_batch_idx >= batch_size:
                    in_batch_idx -= batch_size
                last_item = next(generator)
                in_batch_idx += 1
                yield last_item
        except StopIteration: # Repeats the last sample the necessary number of times
            while in_batch_idx < batch_size:
                yield last_item
                in_batch_idx += 1
    return last_batch_padding_generator

# The final function collects all the data into batches in order to be able to have a variable length batch for the last sample
def collect_batches(generator_factory, batch_size):
    def collect_batches_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        batch = []
        try:
            while True:
                batch.append(next(generator))
                if len(batch) == batch_size:
                    # Converts tuples of samples into tuples of batches of samples
                    yield tuple(map(list, zip(*batch)))
                    batch = []
        except StopIteration:
            if batch:
                # Converts tuples of samples into tuples of batches of samples
                yield tuple(map(list, zip(*batch)))
    return collect_batches_generator

def read_webdataset(
    urls, 
    extensions=None,
    random_shuffle=False, 
    initial_fill=512, 
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet"
):

    # Parsing the input data
    assert(cycle in {"quiet", "raise", "no"})
    if extensions is None:
        extensions = ';'.join(["jpg", "jpeg", "img", "image", "pbm", "pgm", "png"]) # All supported image formats
    if isinstance(extensions, str):
        extensions = (extensions,)

    # For later information for batch collection and padding
    max_batch_size = dali.pipeline.Pipeline.current().max_batch_size

    if isinstance(urls, str) and 's3://' in urls:
        urls = f'pipe:aws s3 cp --quiet {urls} -'
        print(f'S3 paths: {urls}')

    def webdataset_generator():
        bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8),)*len(extensions)
        dataset_instance = (wds.WebDataset(urls, shardshuffle=True)
                            .to_tuple(*extensions)
                            .map_tuple(*bytes_np_mapper)
                        )

        for sample in dataset_instance:
            yield sample

    dataset = webdataset_generator

    # Adding the buffered shuffling
    if random_shuffle:
        dataset = buffered_shuffle(dataset, initial_fill, seed)

    # Adding the batch padding
    if pad_last_batch:
        dataset = last_batch_padding(dataset, max_batch_size)

    # Collecting the data into batches (possibly underfull)
    # Handled by a custom function only when `cycle` is not "quiet"
    if cycle != "quiet":
        dataset = collect_batches(dataset, max_batch_size)

    # Prefetching the data
    if read_ahead:
        dataset=list(dataset())

    return fn.external_source(
        source=dataset,
        num_outputs=len(extensions),
        batch=(cycle != "quiet"), # If `cycle` is "quiet" then batching is handled by the external source
        cycle=cycle,
    )

def decode_augment(img, seed=0):
    img = fn.decoders.image(img)
    img = fn.jitter(img.gpu(), seed=seed)
    img = fn.resize(img, size=(224, 224))
    return img

# Below we define the sample webdataset pipeline with our `external_source`-based loader, 
# that just chains the previously defined reader and augmentation function together.
@dali.pipeline_def(batch_size=batch_size, num_threads=num_workers, device_id=0)
def webdataset_pipeline(
    urls,
    random_shuffle=False, 
    initial_fill=512,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet",
):

    images, labels = read_webdataset(urls=urls, 
                                 extensions=("jpg", "cls"),
                                 random_shuffle=random_shuffle,
                                 initial_fill=initial_fill,
                                 seed=seed,
                                 pad_last_batch=pad_last_batch,
                                 read_ahead=read_ahead,
                                 cycle=cycle,
                                )

    # PyTorch expects labels as INT64
    labels = fn.cast(labels, dtype=types.INT64)

    return decode_augment(images, seed=seed), labels

# PyTorch iterator

urls = './shards/train-{000000..000019}.tar'

pipeline = webdataset_pipeline(
    urls=urls,            # Paths for the sharded dataset
    random_shuffle=True,  # Random buffered shuffling on
    pad_last_batch=False, # The last batch is not padded to the full size
    read_ahead=False,     # The data is not preloaded into memory
    cycle="quiet")        # Restart the source silently when it runs out of data

pipes = [pipeline]
for pipe in pipes:
    pipe.build()

dataset_size = 10222 // batch_size * batch_size
dali_iter = DALIClassificationIterator(pipes, last_batch_policy=LastBatchPolicy.DROP, auto_reset=True, size=dataset_size)

num_epochs = 3
for n in range(num_epochs):
    for data in (pbar := tqdm(dali_iter, leave=True)):
        pbar.set_description(f'{type(data)}')    

Which for me results in:

/home/ubuntu/anaconda3/envs/py38/lib/python3.8/site-packages/nvidia/dali/plugin/base_iterator.py:163: Warning: Please set `reader_name` and don't set last_batch_padded and size manually whenever possible. This may lead, in some situations, to missing some samples or returning duplicated ones. Check the Sharding section of the documentation for more details.
  _iterator_deprecation_warning()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [1], in <module>
    252     pipe.build()
    254 dataset_size = 10222 // batch_size * batch_size
--> 255 dali_iter = DALIClassificationIterator(pipes, last_batch_policy=LastBatchPolicy.DROP, auto_reset=True, size=dataset_size)
    257 num_epochs = 3
    258 for n in range(num_epochs):

File ~/anaconda3/envs/py38/lib/python3.8/site-packages/nvidia/dali/plugin/pytorch.py:367, in DALIClassificationIterator.__init__(self, pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded, last_batch_policy, prepare_first_batch)
    357 def __init__(self,
    358              pipelines,
    359              size=-1,
   (...)
    365              last_batch_policy=LastBatchPolicy.FILL,
    366              prepare_first_batch=True):
--> 367     super(DALIClassificationIterator, self).__init__(pipelines, ["data", "label"],
    368                                                      size,
    369                                                      reader_name=reader_name,
    370                                                      auto_reset=auto_reset,
    371                                                      fill_last_batch=fill_last_batch,
    372                                                      dynamic_shape=dynamic_shape,
    373                                                      last_batch_padded=last_batch_padded,
    374                                                      last_batch_policy=last_batch_policy,
    375                                                      prepare_first_batch=prepare_first_batch)

File ~/anaconda3/envs/py38/lib/python3.8/site-packages/nvidia/dali/plugin/pytorch.py:179, in DALIGenericIterator.__init__(self, pipelines, output_map, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded, last_batch_policy, prepare_first_batch)
    177 if self._prepare_first_batch:
    178     try:
--> 179         self._first_batch = DALIGenericIterator.__next__(self)
    180     except StopIteration:
    181         assert False, "It seems that there is no data in the pipeline. This may happen if `last_batch_policy` is set to PARTIAL and the requested batch size is greater than the shard size."

File ~/anaconda3/envs/py38/lib/python3.8/site-packages/nvidia/dali/plugin/pytorch.py:205, in DALIGenericIterator.__next__(self)
    203 category_shapes = dict()
    204 for category, out in category_outputs.items():
--> 205     category_tensors[category] = out.as_tensor()
    206     category_shapes[category] = category_tensors[category].shape()
    208 category_torch_type = dict()

RuntimeError: [/opt/dali/dali/pipeline/data/tensor_list.h:581] Assert on "this->IsDenseTensor()" failed: All tensors in the input TensorList must have the same shape and be densely packed.

austinmw commented 2 years ago

Ah, it looks like it was the labels, not the images. It works with batch_size > 1 if I add label = label.to_bytes(1, 'big'). Sorry and thanks!
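
For reference, a minimal sketch of where that change would land in the write_dogs_wd function above (my paraphrase, assuming the breed index always fits in a single byte):

# Store the class index as a raw byte instead of letting ShardWriter
# serialize the int to a decimal string such as "10", so that
# np.frombuffer(data, dtype=np.uint8) later yields the numeric label.
label = breed2idx[img_row['breed']].to_bytes(1, 'big')  # assumes label < 256
sample = {'__key__': key, 'jpg': image, 'cls': label}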

JanuszL commented 2 years ago

I was about to write this. WebDataset encodes classes as strings, so the value 10 becomes the string "10". In this example DALI converts it to a uint8 numpy array - bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8),)*len(extensions). What you can do in this particular example is change the mapper function to bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8), lambda data: np.array([int(data)], dtype=np.int64)), so you can avoid labels = fn.cast(labels, dtype=types.INT64). Another option is to use the experimental numba function:

from nvidia.dali.plugin.numba.fn.experimental import numba_function

def label_setup(outs, _):
    for i in range(len(outs)):
        for sample_idx in range(len(outs[i])):
            outs[i][sample_idx][0] = 1
def label_convert(out0, in0):
    val = 0
    for i in range(in0.shape[0]):
        val = (in0[i] - 48) + val * 10
    out0[0] = val

    labels = numba_function(labels, run_fn=label_convert, setup_fn=label_setup,
                                 out_types=[types.INT64], in_types=[types.UINT8],
                                 outs_ndim=[1], ins_ndim=[1])

This way you should be able to use the DALI built-in webdataset reader:

from nvidia.dali.plugin.numba.fn.experimental import numba_function

def label_setup(outs, _):
    for i in range(len(outs)):
        for sample_idx in range(len(outs[i])):
            outs[i][sample_idx][0] = 1
def label_convert(out0, in0):
    val = 0
    for i in range(in0.shape[0]):
        val = (in0[i] - 48) + val * 10
    out0[0] = val

@dali.pipeline_def(batch_size=batch_size, num_threads=num_workers, device_id=0)
def webdataset_pipeline(
    urls,
    random_shuffle=False,
    initial_fill=512,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet",
):

    images, labels = fn.readers.webdataset(paths=urls,
                                 ext=["jpg", "cls"],
                                 random_shuffle=random_shuffle,
                                 initial_fill=initial_fill,
                                 seed=seed,
                                 pad_last_batch=pad_last_batch
                                )
    labels = numba_function(labels, run_fn=label_convert, setup_fn=label_setup,
                                 out_types=[types.INT64], in_types=[types.UINT8],
                                 outs_ndim=[1], ins_ndim=[1])

    return decode_augment(images, seed=seed), labels

urls = [f'./shards/train-0000{i:02}.tar' for i in range(0, 20)]
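
For completeness, a sketch (not from the thread) of hooking the built-in reader up to the PyTorch iterator: if the reader is given a name, e.g. by adding name='Reader' to the fn.readers.webdataset call above, DALIClassificationIterator can query the epoch size from it, which avoids the manual size= argument and the deprecation warning shown earlier.

# Assumes name='Reader' was added to the fn.readers.webdataset call above.
pipeline = webdataset_pipeline(urls=urls, random_shuffle=True)
pipeline.build()
dali_iter = DALIClassificationIterator(
    [pipeline],
    reader_name='Reader',                    # the iterator infers the epoch size
    last_batch_policy=LastBatchPolicy.DROP,
    auto_reset=True)

for data in dali_iter:
    images, labels = data[0]['data'], data[0]['label']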

austinmw commented 2 years ago

Thanks, that's great!