Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0
1.74k stars 213 forks

[Feature request] Compatibility with iterable-style datasets #1237

Open austinmw opened 2 years ago

austinmw commented 2 years ago

🚀 Feature

I'd like to be able to train on iterable-style datasets, not just map-style datasets. (In PyTorch, a map-style dataset implements __getitem__ and __len__, whereas an iterable-style dataset only implements __iter__.)
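To make the distinction concrete, here is a minimal sketch in plain PyTorch (it assumes only that torch is installed; SquaresMap and SquaresStream are made-up names for illustration). Both datasets produce the same batches through a DataLoader, but only the map-style one supports indexing and len().

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class SquaresMap(Dataset):
    """Map-style: random access by index, with a known length."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)

    def __len__(self):
        return self.n


class SquaresStream(IterableDataset):
    """Iterable-style: samples are produced in order; no indexing or length required."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield torch.tensor(i * i)


# The default collate function stacks the 0-d tensors into 1-d batches
map_batches = [batch.tolist() for batch in DataLoader(SquaresMap(5), batch_size=2)]
stream_batches = [batch.tolist() for batch in DataLoader(SquaresStream(5), batch_size=2)]
print(map_batches)  # [[0, 1], [4, 9], [16]]
```

Because the iterable-style dataset never needs to know its length or jump to an arbitrary index, it can be backed by a stream (tar shards, object storage, a database cursor) that is too large to index up front.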

Motivation

Many image datasets in commercial use cases are very large and therefore require iterable-style rather than map-style access. (Users may write custom iterable datasets or use torchdata, webdataset, DALI, etc.)

Pitch

Vision tasks seem to require iterating over the entire dataset and building records before training starts (e.g. ObjectDetectionData). This is an unreasonable requirement for large datasets. Say, for example, you want to compare models on a dataset of 10M images: a full pass over the data that may take several hours before training even begins is an unnecessary and costly step. Users should be able to start training immediately, streaming the data, and have each sample from the iterable dataset carry all the information it needs.
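A minimal sketch of what "each sample carries the information it needs" could look like in plain PyTorch: an iterable dataset that parses one annotation record per line while it streams, so there is no up-front pass over the whole dataset. The JSON-lines layout, the field names (file_name, boxes, labels) and the StreamingDetectionAnnotations class are hypothetical; shards are passed as callables that open a text stream, so the same code works for local files or in-memory test data.

```python
import io
import json

from torch.utils.data import IterableDataset, get_worker_info


class StreamingDetectionAnnotations(IterableDataset):
    """Hypothetical example: yields one (file_name, target) pair per JSON line,
    parsing annotations lazily instead of building all records before training."""

    def __init__(self, shards):
        # Each element of `shards` is a zero-argument callable returning a text stream
        self.shards = shards

    def __iter__(self):
        info = get_worker_info()
        worker_id, num_workers = (0, 1) if info is None else (info.id, info.num_workers)
        for shard_idx, open_shard in enumerate(self.shards):
            # Give each DataLoader worker a disjoint subset of shards
            if shard_idx % num_workers != worker_id:
                continue
            with open_shard() as lines:
                for line in lines:
                    record = json.loads(line)
                    target = {"boxes": record["boxes"], "labels": record["labels"]}
                    yield record["file_name"], target


# Usage with a single in-memory shard of two samples
shard_text = (
    '{"file_name": "a.jpg", "boxes": [[0, 0, 10, 10]], "labels": [1]}\n'
    '{"file_name": "b.jpg", "boxes": [], "labels": []}\n'
)
dataset = StreamingDetectionAnnotations([lambda: io.StringIO(shard_text)])
samples = list(dataset)
```

With DataLoader(num_workers > 1) each worker reads different shards; for multi-GPU training you would additionally split shards by global rank, similar to what the DALI datamodule in the next comment does with shard_id and num_shards.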

In my opinion, the lack of this capability prevents adoption of this library's vision tasks for large-scale image training in commercial settings.

Additional context

lightning-bolts object detectors seem to support this style of dataset already.

Links: https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/ https://github.com/pytorch/data
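WebDataset shards, like the ones the DALI pipeline in the next comment reads with fn.readers.webdataset, are plain tar files in which all files belonging to one sample share a key and sit next to each other (e.g. 000000.jpg, 000000.bboxes, 000000.labels). The sketch below writes such a shard with Python's standard tarfile module. It assumes, based on the reader's dtypes=[types.UINT8, types.FLOAT, types.INT32] argument, that the .bboxes and .labels members hold raw little-endian float32 and int32 bytes; write_detection_shard and the sample tuple layout are hypothetical.

```python
import io
import os
import struct
import tarfile
import tempfile


def _add_member(tar, name, payload):
    """Add an in-memory bytes payload to an open tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


def write_detection_shard(path, samples):
    """Hypothetical helper: write (key, jpeg_bytes, boxes, labels) samples as a WebDataset tar.

    boxes: list of [x, y, W, H] normalized to 0..1; labels: list of int class ids.
    """
    with tarfile.open(path, "w") as tar:
        for key, jpeg_bytes, boxes, labels in samples:
            flat = [v for box in boxes for v in box]
            _add_member(tar, f"{key}.jpg", jpeg_bytes)
            # Raw little-endian float32 values (assumed to match types.FLOAT)
            _add_member(tar, f"{key}.bboxes", struct.pack(f"<{len(flat)}f", *flat))
            # Raw little-endian int32 values (assumed to match types.INT32)
            _add_member(tar, f"{key}.labels", struct.pack(f"<{len(labels)}i", *labels))


# Usage: write one tiny shard and read it back
shard_path = os.path.join(tempfile.mkdtemp(), "train-000000.tar")
write_detection_shard(shard_path, [("000000", b"\xff\xd8 not a real jpeg", [[0.1, 0.2, 0.3, 0.4]], [1])])
with tarfile.open(shard_path) as tar:
    member_names = tar.getnames()
    bbox_bytes = tar.extractfile("000000.bboxes").read()
```

Keeping the keys free of dots and writing each sample's members consecutively keeps the grouping unambiguous for WebDataset-style readers.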

austinmw commented 2 years ago

For more context, below is example code for a custom LightningDataModule that uses DALI and the webdataset format. It works fine with pl_bolts object detectors, with no changes to the data loading and only minimal changes to training_step. I'd still prefer to use Flash detectors over Bolts detectors, though, since Flash offers a larger selection.

import os
import glob
import pickle
import numpy as np
import cv2
import torch
import torchvision.transforms as T
import warnings
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pytorch_lightning as pl
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from nvidia import dali
from nvidia.dali import pipeline_def, types, fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

# Read label map (dict, like 1: person, 2: car, etc.)
with open('coco_idx2label', 'rb') as f:
    idx2label = pickle.load(f)

# Get urls (.tar file paths)
train_dali_urls = sorted(glob.glob(os.path.join(os.getcwd(), 'coco_shards_dali', 'train*')))
val_dali_urls = sorted(glob.glob(os.path.join(os.getcwd(), 'coco_shards_dali', 'val*')))
# For example:
# ['/home/ubuntu/data/coco_shards_dali/train-000000.tar',
#  '/home/ubuntu/data/coco_shards_dali/train-000001.tar',
#  ...
#  '/home/ubuntu/data/coco_shards_dali/train-000031.tar']

class DataModuleClass(pl.LightningDataModule):
    def __init__(self, 
                 idx2label, 
                 train_urls,
                 val_urls=None,
                 batch_size=16,
                 num_workers=os.cpu_count() // max(1, torch.cuda.device_count()),  # avoid ZeroDivisionError without GPUs
                 mean=[103.530, 116.280, 123.675],
                 std=[57.375, 57.120, 58.395],
                 seed=42):

        # Initialise LightningDataModule internals (e.g. self.trainer)
        super().__init__()

        # Define required parameters here
        self.idx2label = idx2label
        self.train_urls = train_urls
        self.val_urls = val_urls
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mean = mean
        self.std = std
        self.seed = seed

        self.prepare_data_per_node = False
        self._log_hyperparams = False

    def prepare_data(self):
        # Define steps that should be done
        # on only one GPU, like getting data.
        pass

    def setup(self, stage=None):
        # Define steps that should be done on 
        # every GPU, like splitting data, applying
        # transform etc.        

        # Create train and val dataloaders

        if hasattr(self.trainer, 'local_rank'):
            device_id = self.trainer.local_rank
            shard_id = self.trainer.global_rank
            num_shards = self.trainer.world_size            
        else:
            warnings.warn('DataModule setup called before trainer init, using default device_id, shard_id, num_shards')
            device_id = 0
            shard_id = 0
            num_shards = 1

        train_pipe = self._wds_pipeline(urls=self.train_urls, 
                                        batch_size=self.batch_size,
                                        num_threads=self.num_workers,
                                        device='gpu',
                                        device_id=device_id, 
                                        shard_id=shard_id,
                                        num_shards=num_shards,
                                        random_shuffle=True,
                                        seed=self.seed,
                                        train=True)

        class LightningWrapper(DALIGenericIterator):
            """Unpack DALI's per-pipeline output list into a plain dict batch."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def __next__(self):
                item = super().__next__()
                images = item[0]['images']
                bboxes = item[0]['bboxes']
                labels = item[0]['labels']
                return {'images': images, 'bboxes': bboxes, 'labels': labels}

        self.train_loader = LightningWrapper(
            train_pipe,
            ['images', 'bboxes', 'labels'],
            reader_name='Reader',
            last_batch_policy=LastBatchPolicy.PARTIAL,
            auto_reset=True)

        if self.val_urls:
            val_pipe = self._wds_pipeline(urls=self.val_urls,
                                          batch_size=self.batch_size,
                                          num_threads=self.num_workers,
                                          device='gpu',
                                          device_id=device_id,
                                          shard_id=shard_id,
                                          num_shards=num_shards,
                                          random_shuffle=False,
                                          seed=self.seed,
                                          train=False)

            self.val_loader = LightningWrapper(
                val_pipe,
                ['images', 'bboxes', 'labels'],
                reader_name='Reader',
                last_batch_policy=LastBatchPolicy.PARTIAL,
                auto_reset=True)

    def train_dataloader(self):
        # Return DataLoader for Training Data here
        return self.train_loader

    def val_dataloader(self):
        # Return DataLoader for Validation Data here
        if self.val_urls:  # same check as setup(), so self.val_loader exists whenever it is returned
            return self.val_loader

    def _decode_augment(self, images, bboxes, labels, device, seed=0, fp16=True, train=True):
        bboxes = fn.reshape(bboxes, shape=[64,4])

        # Adjust boxes due to rounding issues with xyWH format    
        bboxes = dali.math.clamp(bboxes, lo=0.0, hi=1.0)
        xy = bboxes[:,0:2]
        wh = bboxes[:,2:4]
        wh -= dali.math.max(0.0, (xy+wh) - 1.0)
        bboxes = fn.cat(xy,wh, axis=1)

        if train:
            aspect_ratio = [0.5, 2.0]
            thresholds=[0, 0.1, 0.3, 0.5, 0.7, 0.9]
            scaling=[0.3, 1.0]
        else:
            aspect_ratio = [1.0, 1.0]            
            thresholds= [0.9]
            scaling = [1.0, 1.0]

        #input_shape = fn.slice(fn.cast(fn.peek_image_shape(images), dtype=types.INT32), 0, 2, axes=[0])
        crop_begin, crop_size, bboxes, labels = fn.random_bbox_crop(bboxes, labels,
                                                                    device='cpu',
                                                                    aspect_ratio=aspect_ratio,
                                                                    thresholds=thresholds,
                                                                    scaling=scaling,
                                                                    bbox_layout='xyWH',
                                                                    allow_no_crop=True,
                                                                    num_attempts=50)

        #images = fn.decoders.image(images, device='mixed', output_type=types.RGB)
        images = fn.decoders.image_slice(images, crop_begin, crop_size, 
                                         device='mixed' if device == 'gpu' else 'cpu',
                                         output_type=types.RGB)

        if train:
            flip_coin = fn.random.coin_flip(probability=0.5)
        else:
            flip_coin = fn.random.coin_flip(probability=0.0)

        images = fn.resize(images, resize_x=416, resize_y=416,
                           min_filter=types.DALIInterpType.INTERP_TRIANGULAR)

        if train:
            saturation = fn.random.uniform(range=[0.5, 1.5])
            contrast = fn.random.uniform(range=[0.5, 1.5])
            brightness = fn.random.uniform(range=[0.875, 1.125])
            hue = fn.random.uniform(range=[-0.5, 0.5])            

            # Use float output to avoid clipping and quantizing the intermediate result
            images = fn.hsv(images, dtype=types.FLOAT, hue=hue, saturation=saturation)
            images = fn.brightness_contrast(images,
                                            contrast_center=128,  # input is in float, but in 0..255 range
                                            dtype=types.UINT8,
                                            brightness=brightness,
                                            contrast=contrast)

        dtype = types.FLOAT16 if fp16 else types.FLOAT

        bboxes = fn.bb_flip(bboxes, ltrb=False, horizontal=flip_coin)

        images = fn.crop_mirror_normalize(images,
                                          crop=(416, 416),
                                          mean=self.mean,
                                          std=self.std,
                                          mirror=flip_coin,
                                          dtype=dtype,
                                          output_layout='CHW',
                                          pad_output=False)
        # Un-normalize
        bboxes *= 416        

        # Pad
        bboxes = fn.pad(bboxes, fill_value=0.0, axes=(0,), shape=(64,))
        labels = fn.pad(labels, fill_value=0.0, axes=(0,), shape=(64,))

        if device == 'gpu':
            labels = labels.gpu()
            bboxes = bboxes.gpu()

        # Cast to int
        bboxes = fn.cast(bboxes, dtype=types.INT64)
        labels = fn.cast(labels, dtype=types.INT64)  

        return images, bboxes, labels

    @pipeline_def
    def _wds_pipeline(self, 
                      urls,
                      device,
                      shard_id=0,
                      num_shards=1,
                      random_shuffle=True,
                      train=True):
        images, bboxes, labels = fn.readers.webdataset(
            paths=urls,
            shard_id=shard_id, 
            num_shards=num_shards, 
            random_shuffle=random_shuffle,
            #device='mixed' if device == 'gpu' else 'cpu',
            ext=['jpg', 'bboxes', 'labels'],
            missing_component_behavior='error',
            dtypes=[types.UINT8, types.FLOAT, types.INT32],
            seed=self.seed,
            name='Reader')

        return self._decode_augment(images, bboxes=bboxes, labels=labels, device=device, seed=self.seed, train=train)

# Instantiate the datamodule
datamodule = DataModuleClass(
    idx2label, 
    train_urls=train_dali_urls,
    val_urls=val_dali_urls, 
    batch_size=16,
)

# If you need information from the dataset to build your model, then run prepare_data() and setup() manually (Lightning ensures the method runs on the correct devices).
datamodule.prepare_data()
datamodule.setup(stage='fit')
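The datamodule above yields dict batches whose bboxes and labels are zero-padded to 64 entries, with boxes as integer pixel xyWH. torchvision-style detectors instead expect a list of per-image target dicts with float [x1, y1, x2, y2] boxes, so the "minimal modification to training_step" could look roughly like the hypothetical helper below. It assumes padded rows are all zeros (so zero-area boxes can be dropped) and that the model takes (images, targets) in torchvision's format; it is not part of the Flash or Bolts API.

```python
import torch


def dali_batch_to_detection_targets(batch):
    """Hypothetical helper: turn a padded {'images', 'bboxes', 'labels'} batch into
    (images, targets), where targets is a list of per-image dicts with float xyxy boxes."""
    images = batch["images"]
    targets = []
    for boxes, labels in zip(batch["bboxes"], batch["labels"]):
        boxes = boxes.reshape(-1, 4)
        labels = labels.reshape(-1)
        # Padding rows are all zeros, so keep only boxes with positive width * height.
        # Note: this also drops real boxes whose W or H was truncated to 0 by the INT64 cast.
        keep = boxes[:, 2:].prod(dim=1) > 0
        xywh = boxes[keep].float()
        xyxy = torch.cat([xywh[:, :2], xywh[:, :2] + xywh[:, 2:]], dim=1)
        targets.append({"boxes": xyxy, "labels": labels[keep]})
    return images, targets


# Usage with a fake batch of two images; the second one contains only padding
example_batch = {
    "images": torch.zeros(2, 3, 8, 8),
    "bboxes": torch.tensor([[[10, 20, 30, 40], [0, 0, 0, 0]],
                            [[0, 0, 0, 0], [0, 0, 0, 0]]]),
    "labels": torch.tensor([[3, 0], [0, 0]]),
}
images, targets = dali_batch_to_detection_targets(example_batch)
```

Inside training_step, the helper would be applied to the incoming batch before the images and targets are passed to the model.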

ethanwharris commented 2 years ago

Hi @austinmw, thanks for your request! This is a current limitation of certain tasks in Flash: they cannot be used directly with your own datamodule because the model needs to provide the collate function for the data. IceVision models are more complex still, in that they need to provide the entire dataloader. I think it should be possible for us to find a workaround, as this would be a great use case to support 😃
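To illustrate why the collate function matters (a generic PyTorch sketch, not Flash's actual implementation): detection samples each carry a different number of boxes, so PyTorch's default collate function, which stacks tensors, fails on them, and the model has to supply a collate function that keeps the targets as a list. ToyDetectionStream and detection_collate are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyDetectionStream(IterableDataset):
    """Yields (image, target) pairs whose targets hold 1, 3 and 2 boxes respectively."""

    def __iter__(self):
        for num_boxes in (1, 3, 2):
            image = torch.zeros(3, 8, 8)
            target = {
                "boxes": torch.zeros(num_boxes, 4),
                "labels": torch.ones(num_boxes, dtype=torch.int64),
            }
            yield image, target


def detection_collate(batch):
    """Stack the equally sized images, but keep variable-length targets as a list."""
    images, targets = zip(*batch)
    return torch.stack(images), list(targets)


# The default collate tries to torch.stack boxes of shape (1, 4) and (3, 4) and fails
try:
    next(iter(DataLoader(ToyDetectionStream(), batch_size=2)))
    default_collate_failed = False
except RuntimeError:
    default_collate_failed = True

batches = list(DataLoader(ToyDetectionStream(), batch_size=2, collate_fn=detection_collate))
```

Because the right collate function depends on the model's expected input format, a user-supplied datamodule that builds its own DataLoader bypasses it, which is the gap described above.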

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

austinmw commented 2 years ago

Not stale

ethanwharris commented 2 years ago

Hey @austinmw, just to give you an update: we have resolved most of the framework issues that blocked your use case, and now just need to document it properly and ship it in our upcoming 0.8 release. I can't give an exact timeline, but we're aiming for weeks rather than months. I'll come back here when I can share an updated code snippet that makes this work :smiley:

austinmw commented 2 years ago

Awesome news, can't wait to see, thanks!