libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Using torchvision.transforms #247

Closed: ym820 closed this issue 1 year ago.

ym820 commented 1 year ago

Hi, I am new to using ffcv as a data loader. How can I apply functions from torchvision.transforms (e.g. ColorJitter or Grayscale) in the pipeline for data augmentation? I get the error below and am not sure how to fix it. Thanks!

My code:

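import torch
import torchvision.transforms as T
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import (Convert, RandomHorizontalFlip, RandomTranslate,
                             ToDevice, ToTensor, ToTorchImage)
from ffcv.transforms.common import Squeeze

def get_dataloader(batch_size):
    loaders = {}
    for name in ['train', 'test']:
        fname = f'cifar_{name}.beton'
        label_pipeline = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
        image_pipeline = [SimpleRGBImageDecoder()]

        # Add image transforms and normalization
        if name == 'train':
            image_pipeline.extend([
                RandomTranslate(padding=4),
                RandomHorizontalFlip(),
                # torchvision transforms placed here, before ToTensor(),
                # are what trigger the Numba TypingError shown below
                T.ColorJitter(brightness=.5, hue=.3),
                T.Grayscale(),
            ])
        image_pipeline.extend([
            ToTensor(),
            ToDevice('cuda:0', non_blocking=True),
            ToTorchImage(),
            Convert(torch.float32),
        ])

        # Create loaders
        loaders[name] = Loader(fname,
                               batch_size=batch_size,
                               num_workers=16,
                               order=OrderOption.RANDOM,
                               drop_last=(name == 'train'),
                               pipelines={'image': image_pipeline,
                                          'label': label_pipeline})
    return loaders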

The error message:

Exception in thread Thread-26:
Traceback (most recent call last):
  File "/opt/conda/envs/ffcv/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/ffcv/lib/python3.9/site-packages/ffcv/loader/epoch_iterator.py", line 79, in run
    result = self.run_pipeline(b_ix, ixes, slot, events[slot])
  File "/opt/conda/envs/ffcv/lib/python3.9/site-packages/ffcv/loader/epoch_iterator.py", line 133, in run_pipeline
    result = code(*args)
  File "/opt/conda/envs/ffcv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/opt/conda/envs/ffcv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>

File "../opt/conda/envs/ffcv/lib/python3.9/site-packages/ffcv/transforms/module.py", line 25:
        def apply_module(inp, _):
            res = self.module(inp)
            ^

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7f13d0f8cc10>))
During: typing of call at  (2)

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7f13d0f8cc10>))
During: typing of call at  (2)

File "/workspace", line 2:
<source missing, REPL/exec in use?>
ym820 commented 1 year ago

Fixed. It seems the order of the transforms matters.

MBaltz commented 1 year ago

Hello @akai820, can you tell me more about how you solved this?

ym820 commented 1 year ago


In the example they provide (https://docs.ffcv.io/ffcv_examples/cifar10.html), the PyTorch transforms go at the end of the pipeline, so I moved T.ColorJitter(brightness=.5, hue=.3) and T.Grayscale() to the end and it worked.
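A sketch of why the reordering helps, based on my reading of the traceback and FFCV's source rather than anything confirmed in this thread: stages before ToTensor() are compiled with Numba in nopython mode and operate on NumPy arrays, so a torchvision module placed there gets wrapped in ModuleWrapper, whose generated function calls self.module, which Numba cannot type (the "Untyped global name 'self'" line above). From ToTensor() on, stages appear to run as ordinary Python on batched torch tensors, and after ToTorchImage() the batch is channel-first, the layout torchvision's tensor transforms expect. The checker below is hypothetical (not an FFCV API), and the stage names are plain strings mirroring the original post.

```python
# Hypothetical checker, not part of FFCV: flags torchvision stages that come
# before both ToTensor() and ToTorchImage(), i.e. in the Numba-compiled,
# channel-last part of the pipeline.
TORCHVISION_STAGES = {"ColorJitter", "Grayscale", "Normalize"}

# Train-split stage order from the original post.
ORIGINAL_TRAIN = [
    "SimpleRGBImageDecoder", "RandomTranslate", "RandomHorizontalFlip",
    "ColorJitter", "Grayscale",
    "ToTensor", "ToDevice", "ToTorchImage", "Convert",
]

# Reordered: torchvision stages after ToTorchImage, but before Convert.
FIXED_TRAIN = [
    "SimpleRGBImageDecoder", "RandomTranslate", "RandomHorizontalFlip",
    "ToTensor", "ToDevice", "ToTorchImage",
    "ColorJitter", "Grayscale",
    "Convert",
]

def misplaced_torchvision_stages(stages):
    """Return torchvision stages seen before both ToTensor and ToTorchImage."""
    seen = set()
    misplaced = []
    for name in stages:
        if name in TORCHVISION_STAGES and not {"ToTensor", "ToTorchImage"} <= seen:
            misplaced.append(name)
        seen.add(name)
    return misplaced

print(misplaced_torchvision_stages(ORIGINAL_TRAIN))  # ['ColorJitter', 'Grayscale']
print(misplaced_torchvision_stages(FIXED_TRAIN))     # []
```

The fixed order puts the color transforms before Convert(torch.float32) rather than at the very end because, per torchvision's documented behavior, its color ops assume float images lie in [0, 1], while this pipeline's float tensors stay in [0, 255]; on the still-uint8 tensor they work as intended. Also note that T.Grayscale() outputs one channel by default (pass num_output_channels=3 if the model expects RGB input), and that in this position each transform receives a whole batch, so ColorJitter draws one set of random factors per batch rather than per image.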

MBaltz commented 1 year ago

All right, @akai820. I had this problem too (the same traceback ending in "<source missing, REPL/exec in use?>").

In my case, though, I was using an NDArrayField to encode a NumPy array with dtype float16, and decoding it with NDArrayDecoder failed.

To solve it, I simply stored the NumPy array as float32 instead of float16, and it worked!
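That workaround amounts to upcasting before the dataset is written. The helper name below is hypothetical, and the reason float16 fails (presumably limited float16 support in the Numba-compiled decoding path) is an assumption. Upcasting is lossless, since every float16 value is exactly representable in float32, so the only cost is twice the storage. The field declaration in the docstring follows the NDArrayField(shape=..., dtype=...) form from FFCV's docs.

```python
import numpy as np

def upcast_for_ffcv(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: store float16 data as float32 before writing.

    The matching field would then use a float32 dtype, e.g.
    NDArrayField(shape=<per-sample shape>, dtype=np.dtype('float32')).
    """
    if arr.dtype == np.float16:
        # Exact conversion: float32 can represent every float16 value.
        return arr.astype(np.float32)
    return arr

sample = np.array([0.1, -2.5, 65504.0], dtype=np.float16)
stored = upcast_for_ffcv(sample)
print(stored.dtype)                                       # float32
print(stored.nbytes // sample.nbytes)                     # 2
print(np.array_equal(stored.astype(np.float16), sample))  # True
```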

briteroses commented 1 year ago

I know this is closed, but just so it's fully clear what happened here:

The image decoder, before any pipeline operations, returns NumPy arrays in BGR color channel order. That's why the vast majority of examples and applications use ToTensor() -> ToTorchImage() in the pipeline (ToTorchImage, among other things, performs a BGR -> RGB conversion). If you're using a torchvision transform that expects an RGB tensor, you have to place it after ToTensor() and ToTorchImage().

If you later want a custom transform that operates on a BGR NumPy array or on a BGR tensor, put it before ToTensor or between ToTensor and ToTorchImage, respectively.
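The layout side of this can be illustrated with NumPy alone. Decoded image batches arrive as (B, H, W, C) uint8 arrays, and ToTorchImage() permutes them to (B, C, H, W); torchvision's tensor transforms read the channel axis at dim -3, so they only see the color channels after that permutation. The helpers below are hypothetical stand-ins for those two steps and leave channel order (RGB vs. BGR) aside; the luma weights are the ones torchvision's rgb_to_grayscale uses.

```python
import numpy as np

# ITU-R 601-2 luma weights (as in torchvision's rgb_to_grayscale).
LUMA = np.array([0.2989, 0.587, 0.114], dtype=np.float32)

def nhwc_to_nchw(batch: np.ndarray) -> np.ndarray:
    """Mimic ToTorchImage's axis permutation: (B, H, W, C) -> (B, C, H, W)."""
    return batch.transpose(0, 3, 1, 2)

def grayscale_channel_first(batch: np.ndarray) -> np.ndarray:
    """Grayscale that reads colors from dim -3, like torchvision tensor ops."""
    if batch.shape[-3] != 3:
        raise ValueError(f"expected 3 channels at dim -3, got shape {batch.shape}")
    weights = LUMA.reshape(3, 1, 1)  # broadcast over H and W
    return (batch.astype(np.float32) * weights).sum(axis=-3, keepdims=True)

# A batch shaped like SimpleRGBImageDecoder's output: 2 images, 32x32, 3 channels.
decoded = np.full((2, 32, 32, 3), 255, dtype=np.uint8)

try:
    grayscale_channel_first(decoded)  # dim -3 is H (32 here), not C
except ValueError as err:
    print(err)

gray = grayscale_channel_first(nhwc_to_nchw(decoded))
print(gray.shape)  # (2, 1, 32, 32)
```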