libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0
2.82k stars 178 forks

Torchvision transformations (inheriting Module) break loader #134

Closed — njwfish closed this issue 2 years ago

njwfish commented 2 years ago

Using torchvision transformations, I see the following error:

Exception in thread Thread-6:
Traceback (most recent call last):
  File "/home/fishman/miniconda3/envs/ffcv/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/home/fishman/miniconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/loader/epoch_iterator.py", line 79, in run
    result = self.run_pipeline(b_ix, ixes, slot, events[slot])
  File "/home/fishman/miniconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/loader/epoch_iterator.py", line 133, in run_pipeline
    result = code(*args)
  File "/home/fishman/miniconda3/envs/ffcv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/fishman/miniconda3/envs/ffcv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>

File "/homes/fishman/miniconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/transforms/module.py", line 25:
        def apply_module(inp, _):
            res = self.module(inp)
            ^

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fad53753d30>))
During: typing of call at  (2)

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fad53753d30>))
During: typing of call at  (2)

File "/data/ziz/not-backed-up/fishman/data-augmentation", line 2:
<source missing, REPL/exec in use?>

I get this issue running exactly the CIFAR-10 example on a single GPU, with torchvision.transforms.ColorJitter(.4, .4, .4) added to the loader. I am on Fedora with CUDA 11.2 and a GTX 1080. I see similar errors with other torchvision transforms. I installed using the conda environment as specified; native FFCV transforms work as expected.

GuillaumeLeclerc commented 2 years ago

Hi @njwfish

What version are you using? Can I see the complete pipeline?

PS: Torchvision augmentations are really slow; we do not recommend using them if you care about performance at all. In most cases they will be slower with FFCV than with the PyTorch DataLoader. We mostly support them to let people experiment, but that should be it.

njwfish commented 2 years ago

This is the pipeline code. I get this issue with both the pip-installed version and a build from the main git branch here.

CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

DUPLICATES = 2
BATCH_SIZE = 512 // DUPLICATES

loaders = {}
for name in ['train', 'test']:
    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
    image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]

    # Add image transforms and normalization
    if name == 'train':
        image_pipeline.extend([
            RandomHorizontalFlip(),
            torchvision.transforms.ColorJitter(.4,.4,.4),
            RandomTranslate(padding=2),
            Cutout(8, tuple(map(int, CIFAR_MEAN))), # Note Cutout is done before normalization.
        ])
    image_pipeline.extend([
        ToTensor(),
        ToDevice('cuda:0', non_blocking=True),
        ToTorchImage(),
        Convert(ch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # Create loaders
    loaders[name] = Loader(f'/{data_dir}/cifar_{name}.beton',
                            batch_size=BATCH_SIZE,
                            num_workers=8,
                            order=OrderOption.RANDOM,
                            drop_last=(name == 'train'),
                            pipelines={'image': image_pipeline,
                                       'label': label_pipeline})

This is exactly the code from the CIFAR-10 notebook, with the addition of a torchvision transformation.

andrewilyas commented 2 years ago

Hi @njwfish ! I think the problem might be that your torchvision transforms are happening before the ToTensor, so they're acting on NumPy arrays. Try moving them to after the ToTorchImage() transformation and see if that fixes your issue! As @GuillaumeLeclerc mentioned though, torchvision augmentations are likely significantly slower than FFCV counterparts.
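Applied to the pipeline posted above, that suggestion would look roughly like the fragment below. This is an untested sketch (it reuses the names from the snippet above: `CIFAR_MEAN`, `CIFAR_STD`, `ch`, etc.): the FFCV decoder and NumPy-based transforms stay first, and the torchvision ops move after `ToTorchImage()` so they receive torch tensors rather than NumPy arrays. `ColorJitter` is placed before the `float16` conversion here on the assumption that it should see uint8 pixel values in [0, 255], since torchvision's tensor ops interpret float inputs as being in [0, 1].

```python
# Sketch of a reordered train pipeline (hypothetical fix, not from the thread).
# Torchvision transforms now come after ToTorchImage(), so they operate on
# torch tensors and are never fed into FFCV's Numba-compiled stage.
image_pipeline: List[Operation] = [
    SimpleRGBImageDecoder(),
    RandomHorizontalFlip(),
    RandomTranslate(padding=2),
    Cutout(8, tuple(map(int, CIFAR_MEAN))),  # Cutout before normalization
    ToTensor(),
    ToDevice('cuda:0', non_blocking=True),
    ToTorchImage(),
    torchvision.transforms.ColorJitter(.4, .4, .4),  # now acts on a torch tensor
    Convert(ch.float16),
    torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
```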

GuillaumeLeclerc commented 2 years ago

It seems that @andrewilyas is correct. If it doesn't work, feel free to re-open the issue.