libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Compatibility with timm augmentation? #195

Closed: ardasahiner closed this 2 years ago

ardasahiner commented 2 years ago

Hi,

I was attempting to use FFCV with timm, relying on the fact that torch.nn.Module instances should be compatible with the pipelines argument of FFCV's Loader. However, I am getting some strange errors and would like some clarification on what is going wrong.

Please see my minimal reproducible example below. I use CIFAR100 images and timm's create_transform function. Since not every transform it returns is an instance of nn.Module, I wrapped those in a simple module, CustomClass. However, I get the error shown below.

Do you have any suggestions as to what is causing this issue, or any ideas for a simpler integration with timm? Any help is appreciated.

Error:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>

File "../anaconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/transforms/module.py", line 25:
        def apply_module(inp, _):
            res = self.module(inp)
            ^

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)
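The traceback points at the function that `ModuleWrapper.generate_code` builds: `apply_module` is a closure that reads `self.module`, so Numba has to assign a type to the captured `self`, which is an arbitrary Python object. The sketch below reproduces that pattern with a hypothetical `ModuleWrapperSketch` class (not FFCV's actual source) and lists the captured variables with the standard library's `inspect.getclosurevars`:

```python
import inspect


class ModuleWrapperSketch:
    """Hypothetical stand-in for ffcv.transforms.module.ModuleWrapper (not FFCV's code)."""

    def __init__(self, module):
        self.module = module

    def generate_code(self):
        # Same shape as the apply_module closure in the traceback: it reads
        # self.module at call time, so `self` is a free variable of the closure.
        def apply_module(inp, _):
            return self.module(inp)

        return apply_module


fn = ModuleWrapperSketch(abs).generate_code()
captured = inspect.getclosurevars(fn).nonlocals
print(sorted(captured))  # ['self']
print(fn(-3, None))      # 3
```

Plain Python runs this without complaint. The problem only appears when Numba compiles `apply_module` in nopython mode: it cannot infer a type for `self`, which produces the `Untyped global name 'self'` error.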

Implementation:

import torch
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, NormalizeImage, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

class CustomClass(torch.nn.Module):
    def __init__(self, transform):
        super().__init__()
        self.transform = transform

    def get_params(self, img, scale, ratio):
        # Delegate to the wrapped transform; this needs self, so it cannot be a staticmethod
        return self.transform.get_params(img, scale, ratio)

    def forward(self, img):
        return self.transform(img)

    def __repr__(self):
        return self.transform.__repr__()

is_train = True
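# timm's create_transform normalizes after converting to a [0, 1] float tensor,
# so the ImageNet statistics are used unscaled here
imnet_mean, imnet_std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD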

paths = {
    'train': 'cifar100_train.beton',
    'test': 'cifar100_test.beton'
}

to_module = create_transform(
    input_size=224,
    is_training=True,
    color_jitter=0.4,
    auto_augment='rand-m9-mstd0.5-inc1',
    interpolation='bicubic',
    re_prob=0.25,
    re_mode='pixel',
    re_count=False,
    mean=imnet_mean,
    std=imnet_std,
)

module_list = []
for t in to_module.transforms:
    if isinstance(t, torch.nn.Module):
        module_list.append(t)
    else:
        t_new = CustomClass(t)
        module_list.append(t_new)

transform = torch.nn.Sequential(*module_list)

label_pipeline = [IntDecoder(), ToTensor(), Squeeze()]
image_pipeline = [SimpleRGBImageDecoder(), transform]

ordering = OrderOption.QUASI_RANDOM if is_train else OrderOption.SEQUENTIAL
dataset = Loader(paths['train'] if is_train else paths['test'], batch_size=10, num_workers=2,
                 order=ordering, drop_last=is_train, os_cache=True, distributed=False,
                 pipelines={'image': image_pipeline, 'label': label_pipeline})

for i, (image, label) in enumerate(dataset):
    if i == 1:
        break
    print('loaded one batch')
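One way to keep timm's augmentations, at the cost of FFCV's compiled speed for those steps, is to take them out of the `pipelines` argument and apply them per sample in ordinary Python after the loader yields a batch. A minimal NumPy-only sketch, assuming the image pipeline stops at a decoder that yields a uint8 batch of shape (N, H, W, C); `apply_per_sample` is a hypothetical helper, and `per_sample_fn` stands in for something like a timm transform adapted to accept a NumPy array (for example via `PIL.Image.fromarray`, not shown):

```python
import numpy as np


def apply_per_sample(batch, per_sample_fn):
    """Hypothetical helper: run an ordinary Python transform on each image of a batch.

    batch: uint8 array of shape (N, H, W, C), as assumed to come out of the loader.
    per_sample_fn: any callable mapping one (H, W, C) array to one array; all
    outputs must share a shape so they can be stacked back into a batch.
    """
    outputs = [per_sample_fn(img) for img in batch]
    return np.stack(outputs, axis=0)


# Toy transform standing in for a timm pipeline:
# convert to float in [0, 1] and move channels first (C, H, W).
def to_chw_float(img):
    return img.astype(np.float32).transpose(2, 0, 1) / 255.0


batch = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
out = apply_per_sample(batch, to_chw_float)
print(out.shape, out.dtype)  # (4, 3, 32, 32) float32
```

This is the per-image Python loop that FFCV is designed to avoid, so expect it to be noticeably slower than FFCV's own compiled transforms.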

GuillaumeLeclerc commented 2 years ago

Hello @ardasahiner. timm augmentations are not Numba augmentations, so they cannot benefit from FFCV's compiled pipeline. #154 might be of some help (perhaps you could even contribute to it). Unfortunately, our team is too small to add support for every existing augmentation, so we hope to get help from the community with this.
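For anyone considering such a contribution, a "Numba augmentation" roughly means a function over whole NumPy batches that Numba can compile in nopython mode. It may therefore use only arrays, scalars and supported NumPy calls, never arbitrary Python objects such as a timm transform. The `apply_module(inp, _)` signature in the traceback hints at the shape: an input batch plus a destination buffer. The following is a minimal sketch of a random horizontal flip in that restricted style. The function name and `flip_prob` parameter are illustrative, wiring it into FFCV's `Operation` class is not shown, and the code falls back to plain Python when Numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # run as plain Python if Numba is not installed
    def njit(func):
        return func


@njit
def random_flip_batch(images, dst, flip_prob):
    # images, dst: uint8 arrays of shape (N, H, W, C).
    # Only arrays, scalars and np.random calls are used, so nopython mode can type everything.
    for i in range(images.shape[0]):
        if np.random.rand() < flip_prob:
            dst[i] = images[i, :, ::-1]  # reverse the width axis
        else:
            dst[i] = images[i]
    return dst


images = np.arange(2 * 2 * 3 * 1, dtype=np.uint8).reshape(2, 2, 3, 1)
flipped = random_flip_batch(images, np.empty_like(images), 1.0)
```

Passing `flip_prob=1.0` flips every image and `0.0` flips none, which makes the behaviour easy to check before adding real randomness.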