NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

ExternalSource and fn.readers.file #3811

Open · qsunyuan opened this issue 2 years ago

qsunyuan commented 2 years ago
  1. CodeBase1: I need to use a weighted random sampler for ImageNet-1000, so I checked this issue (link). Then I followed this link to write my own code, including a Sampler, an ExternalSourcePipeline, and a DALIGenericIterator.

  2. CodeBase2: I also checked this ImageNet-1000 ResNet example (link).

However, the difference in data-loading efficiency between the two is surprisingly large!

I'm not sure why; a possible reason may be the difference between fn.readers.file (used in link) and np.fromfile (used in link):

for _ in range(self.batch_size):
    jpeg_filename, label = self.files[self.i % self.n].split(' ')
    batch.append(np.fromfile(self.images_dir + jpeg_filename, dtype = np.uint8))  # we can use numpy
    labels.append(torch.tensor([int(label)], dtype = torch.uint8)) # or PyTorch's native tensors
    self.i += 1

How can I achieve efficiency similar to CodeBase2 while still using a weighted random sampler?
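
For comparison, the reader-based loading in CodeBase2 looks roughly like the sketch below (illustrative only; data_dir is an assumed ImageNet-style directory with one subfolder per class, and this path does not include the weighted sampling I need):

import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline

def create_reader_pipeline(batch_size, num_threads, device_id, data_dir):
    pipe = Pipeline(batch_size, num_threads, device_id)
    with pipe:
        # fn.readers.file lists, shuffles, and loads the encoded JPEGs in native code,
        # so no Python runs on the per-sample critical path.
        jpegs, labels = fn.readers.file(file_root=data_dir,
                                        random_shuffle=True,
                                        name="Reader")
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        pipe.set_outputs(images, labels)
    return pipe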

JanuszL commented 2 years ago

Hi @qsunyuan,

Please check the most recent improvements to the external source operator, especially the parallel and prefetch_queue_depth options. They should improve performance by moving the loading itself to a side thread/process, so it is no longer on the critical path.
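
For illustration, enabling these options looks roughly like the sketch below (not code from this issue; ExternalInputCallable, files, and labels are placeholder names for a per-sample source and your file list):

import numpy as np
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline

class ExternalInputCallable:
    # Assumed per-sample source: given a nvidia.dali.types.SampleInfo,
    # it returns (encoded JPEG bytes, label) as numpy arrays.
    def __init__(self, files, labels):
        self.files = files
        self.labels = labels

    def __call__(self, sample_info):
        if sample_info.idx_in_epoch >= len(self.files):
            raise StopIteration  # signals the end of the epoch
        encoded = np.fromfile(self.files[sample_info.idx_in_epoch], dtype=np.uint8)
        label = np.array([self.labels[sample_info.idx_in_epoch]], dtype=np.uint8)
        return encoded, label

pipe = Pipeline(batch_size=256, num_threads=8, device_id=0,
                prefetch_queue_depth=2,        # overlap loading with execution
                py_num_workers=4,              # worker processes for the source
                py_start_method="spawn")
with pipe:
    jpegs, labels = fn.external_source(source=ExternalInputCallable(files, labels),
                                       num_outputs=2,
                                       batch=False,    # per-sample callable
                                       parallel=True,  # run the source in the workers
                                       dtype=types.UINT8)
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    pipe.set_outputs(images, labels)
pipe.build()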

qsunyuan commented 2 years ago

import os
import random
import torch
import numpy as np

import nvidia.dali.types as types
import nvidia.dali.fn as fn

from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator 
from nvidia.dali.plugin.pytorch import LastBatchPolicy

from tqdm import tqdm
from math import ceil

class SamplerIterator(object):
    """Iterable source for fn.external_source: yields (encoded JPEG, label) batches,
    re-drawing a weighted sample of the dataset at the start of each epoch."""
    def __init__(self, 
                 data, targets, 
                 batch_size,
                 flags):

        # Keep the originals so every epoch is re-sampled from the full dataset
        # (and stays aligned with `flags`).
        self.base_data = data
        self.base_targets = targets
        self.data = data
        self.targets = targets
        self.batch_size = batch_size
        self.size = len(data)

        # Per-sample weights: samples where `flags` is True get weights1, the rest weights2.
        self.weights1 = 1
        self.weights2 = 2
        self.flags = flags

    def __iter__(self):
        self.i = 0

        # Draw a weighted sample of the whole dataset (with replacement) for this epoch.
        weights = torch.where(self.flags, self.weights1, self.weights2).float()
        index = torch.multinomial(weights, self.size, replacement=True)

        self.data = self.base_data[index]
        self.targets = self.base_targets[index]

        return self

    def __next__(self):
        batch = []
        labels = [] 

        if self.i >= self.size:
            self.__iter__()
            raise StopIteration

        for _ in range(self.batch_size):
            data = self.data[self.i % self.size]
            target = self.targets[self.i % self.size]
            batch.append(np.fromfile(data, dtype=np.uint8)) 
            labels.append(target.astype(np.uint8)) 
            self.i += 1

            if self.i == self.size:
                break

        return (batch, labels)

    def __len__(self):
        return ceil(self.size / self.batch_size)

    next = __next__

# Decoder memory padding / preallocation hints (values from the DALI ResNet-50 example).
device_memory_padding = 211025920
host_memory_padding = 140544512
preallocate_width_hint = 5980
preallocate_height_hint = 6430

def create_dali_pipeline(batch_size, num_threads, device_id, sampler_iterator, is_training=True):
    pipe = Pipeline(batch_size, num_threads, device_id) 
    with pipe:
        jpegs, labels = fn.external_source(source=sampler_iterator,
                                           num_outputs=2,
                                           dtype=types.UINT8)
        if is_training:
            images = fn.decoders.image_random_crop(jpegs,
                                               device="mixed", output_type=types.RGB,
                                               device_memory_padding=device_memory_padding,
                                               host_memory_padding=host_memory_padding,
                                               preallocate_width_hint=preallocate_width_hint,
                                               preallocate_height_hint=preallocate_height_hint,
                                               random_aspect_ratio=[0.75, 1.333333],
                                               random_area=[0.08, 1.0],
                                               num_attempts=100)
            images = fn.resize(images, device="gpu", 
                               resize_x=224, resize_y=224, 
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = fn.random.coin_flip(probability=0.5)
        else:
            images = fn.decoders.image(jpegs,
                                        device="mixed",
                                        output_type=types.RGB)
            images = fn.resize(images,
                               device="gpu",
                               resize_x=256, resize_y=256, 
                               mode="not_smaller",
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = False

        images = fn.crop_mirror_normalize(images.gpu(),
                                          dtype=types.FLOAT,
                                          output_layout="CHW",
                                          crop=(224, 224), 
                                          mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                          std=[0.229 * 255,0.224 * 255,0.225 * 255],
                                          mirror=mirror)
        labels = labels.gpu()
        pipe.set_outputs(images, labels)

    return pipe

class ClassificationIterator(DALIGenericIterator):
    def __init__(self,
                 sampler_iterator,
                 pipelines,
                 size=-1, 
                 reader_name=None, 
                 auto_reset=False, 
                 fill_last_batch=None, 
                 dynamic_shape=False, 
                 last_batch_padded=True, 
                 last_batch_policy=LastBatchPolicy.PARTIAL,
                 prepare_first_batch=True):
        super(ClassificationIterator, self).__init__(pipelines, ["data", "label"],
                                                         size,
                                                         reader_name=reader_name,
                                                         auto_reset=auto_reset,
                                                         fill_last_batch=fill_last_batch,
                                                         dynamic_shape=dynamic_shape,
                                                         last_batch_padded=last_batch_padded,
                                                         last_batch_policy=last_batch_policy,
                                                         prepare_first_batch=prepare_first_batch)
        self.sampler_iterator = sampler_iterator

    def __len__(self):
        return len(self.sampler_iterator)

def get_dali_dataloader(data, targets, flags,
                        batch_size,
                        device_id=0,
                        num_threads=16,
                        is_training=True
                        ):
    # `flags` is forwarded to the sampler to build the per-sample weights.
    sampler_iterator = SamplerIterator(data,
                                       targets,
                                       batch_size,
                                       flags)

    pipe = create_dali_pipeline(batch_size=batch_size, 
                                num_threads=num_threads, 
                                device_id=device_id,
                                sampler_iterator=sampler_iterator,
                                is_training=is_training)

    data_loader = ClassificationIterator(sampler_iterator, pipe)

    return data_loader

qsunyuan commented 2 years ago

Thanks for your quick reply. This is my current code.

Do you mean that fn.readers.file should have a parallel mode enabled?

JanuszL commented 2 years ago

Hi,

I mean the external_source operator, which has new options that should speed things up. Please check the example I referred to in my previous message.

qsunyuan commented 2 years ago

I followed the reference.

  1. Shuffling is cumbersome with this approach.

My shuffle code is as follows:

from sklearn.utils import shuffle

# Derive the shuffle order from the epoch index so it is reproducible within an epoch.
epoch_idx_seed = sample_info.epoch_idx
shuffle_idx = shuffle(np.arange(len(self.files)), random_state=epoch_idx_seed)
  2. Even without shuffling, external_source with parallel enabled seems slower when iterated with a naive loop:
for data in tqdm(loader):
    pass
qsunyuan commented 2 years ago

Did I miss something?

Thanks in advance.

qsunyuan commented 2 years ago

Another point of confusion: when I use the DALI loader from https://github.com/NVIDIA/DALI/blob/fcdbcd1f861ce862173f86005203026c072862c1/docs/examples/use_cases/pytorch/resnet50/main.py#L95, the efficiency is very good.

qsunyuan commented 2 years ago

I modified the shuffle code as follows:

    def __call__(self, sample_info):
        sample_idx = sample_info.idx_in_epoch
        epoch_idx_seed = sample_info.epoch_idx 

        # Reshuffle only when a new epoch starts, instead of on every call.
        if epoch_idx_seed != self.curr_epoch_idx_seed:
            shuffle_idx = shuffle(np.arange(len(self.files)), random_state=epoch_idx_seed)
            self.files = self.files[shuffle_idx]
            self.labels = self.labels[shuffle_idx]

            self.curr_epoch_idx_seed = sample_info.epoch_idx 

It's a lot faster, but still doesn't beat external_source without parallel.

JanuszL commented 2 years ago

We have recently updated the documentation with more details on the external source operation principle and on how and when the work can be split. Please check this particular section of the documentation. If the data source is stateful (for example, an iterable or a generator), it is run by only one worker; if it is a stateless function, it can be processed by multiple processes in parallel.
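
Concretely, the SamplerIterator above is a stateful iterable, so it is run by a single worker even with parallel enabled. A stateless per-sample callable that derives everything from sample_info can be split across workers. Below is a rough sketch of such a source with weighted sampling (illustrative names, not code from this issue):

import numpy as np

class WeightedSampleSource:
    # Assumed stateless per-sample source: every call derives its result from
    # sample_info alone, so DALI can run copies of it in several worker processes.
    def __init__(self, files, labels, weights, samples_per_epoch):
        self.files = np.asarray(files)
        self.labels = np.asarray(labels)
        self.weights = np.asarray(weights, dtype=np.float64)
        self.weights /= self.weights.sum()
        self.samples_per_epoch = samples_per_epoch

    def __call__(self, sample_info):
        if sample_info.idx_in_epoch >= self.samples_per_epoch:
            raise StopIteration
        # Re-derive this epoch's weighted sample order from the epoch index only,
        # so every worker computes the same order without sharing any state.
        rng = np.random.default_rng(seed=sample_info.epoch_idx)
        order = rng.choice(len(self.files), size=self.samples_per_epoch,
                           p=self.weights, replace=True)
        i = order[sample_info.idx_in_epoch]
        encoded = np.fromfile(self.files[i], dtype=np.uint8)
        return encoded, np.array([self.labels[i]], dtype=np.uint8)

Recomputing the whole order on every call is wasteful; caching it keyed on sample_info.epoch_idx (as in the modified shuffle code above) avoids the rework while keeping the result deterministic, so every worker can still reproduce it independently. The source would then be passed to fn.external_source with batch=False and parallel=True, as in the earlier sketch.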