libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Fix RandomTranslate #267

Closed kristian-georgiev closed 1 year ago

kristian-georgiev commented 1 year ago

Currently, RandomTranslate ignores its fill argument and instead leaves random (garbage) values in the padded region. The proposed changes fix this.
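For reference, here is a minimal NumPy sketch of the intended behavior. It assumes a batch of uint8 images in NHWC layout (the layout an RGB decoder typically hands to the next operation). The function name random_translate_fill is hypothetical: this is not FFCV's Numba-compiled implementation, only an illustration of "pad with a constant color, then crop at a random per-image offset".

```python
import numpy as np


def random_translate_fill(images, padding, fill, rng=None):
    """Hypothetical helper: shift each image by up to `padding` px in x and y.

    images: uint8 array of shape (N, H, W, C).
    fill:   length-C sequence of ints used for every exposed border pixel.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, c = images.shape

    # Build the padded canvas and initialize it with the fill color first.
    # If the canvas came from np.empty and this step were skipped, the border
    # would hold whatever bytes happened to be in memory ("garbage").
    padded = np.empty((n, h + 2 * padding, w + 2 * padding, c), dtype=images.dtype)
    padded[...] = np.asarray(fill, dtype=images.dtype)
    padded[:, padding:padding + h, padding:padding + w] = images

    # Independent crop offset per image, each in [0, 2 * padding].
    ys = rng.integers(0, 2 * padding + 1, size=n)
    xs = rng.integers(0, 2 * padding + 1, size=n)

    out = np.empty_like(images)
    for i in range(n):
        out[i] = padded[i, ys[i]:ys[i] + h, xs[i]:xs[i] + w]
    return out


# Example: translate a batch of solid-white 32x32 RGB images.
batch = np.full((4, 32, 32, 3), 255, dtype=np.uint8)
shifted = random_translate_fill(batch, padding=2, fill=(125, 122, 113),
                                rng=np.random.default_rng(0))
```

Because the canvas is initialized with fill before the image is copied in, every exposed border pixel takes the fill color, while the central region that no shift can reach always comes from the original image.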

kristian-georgiev commented 1 year ago

As a minimal example (essentially a trimmed-down version of the CIFAR-10 example):

from typing import List
from matplotlib import pyplot as plt

import torch as ch
import torchvision

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

BETON_PATH = '/tmp/cifar_translate_test.beton'

# Write the CIFAR-10 training set to an FFCV .beton file.
dataset = torchvision.datasets.CIFAR10('/tmp', train=True, download=True)

writer = DatasetWriter(BETON_PATH, {
        'image': RGBImageField(),
        'label': IntField()
    })
writer.from_indexed_dataset(dataset)

# Per-channel CIFAR-10 mean and std on the 0-255 pixel scale.
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
image_pipeline: List[Operation] = [SimpleRGBImageDecoder(),
                                   # Pad by 2 px and request the mean color in the padding.
                                   RandomTranslate(padding=2, fill=tuple(map(int, CIFAR_MEAN))),
                                   ToTensor(),
                                   ToDevice('cuda:0', non_blocking=True),
                                   ToTorchImage(),
                                   Convert(ch.float16),
                                   torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
                                  ]

loader = Loader(BETON_PATH,
                batch_size=512,
                num_workers=8,
                order=OrderOption.SEQUENTIAL,
                drop_last=True,
                pipelines={'image': image_pipeline,
                           'label': label_pipeline})

# Undo the normalization on the first image of the first batch and plot it.
mu = ch.tensor(CIFAR_MEAN).reshape([-1, 1, 1])
sig = ch.tensor(CIFAR_STD).reshape([-1, 1, 1])
ims, labs = next(iter(loader))
z = (ims[0].cpu().float() * sig + mu).round().clamp(0, 255).to(ch.uint8)
plt.imshow(z.permute([1, 2, 0]).numpy())
plt.show()

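produces images like this one, where the padded border contains random garbage values:

[image: sample output of the script above]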

The expected behavior is that the padding is filled with a constant greyish color (the per-channel CIFAR-10 mean, passed as fill).
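To see what a correctly filled border should look like, here is the arithmetic the repro script applies to a padded pixel, written out in plain NumPy. It is a worked sketch: it mirrors what Normalize computes, (x - mean) / std, and the de-normalization in the plotting step, x * std + mean, rather than calling either.

```python
import numpy as np

# Per-channel CIFAR-10 statistics, copied from the reproduction script.
CIFAR_MEAN = np.array([125.307, 122.961, 113.8575])
CIFAR_STD = np.array([51.5865, 50.847, 51.255])

# The script passes the mean, truncated to ints, as the fill color: (125, 122, 113).
fill = np.array(tuple(map(int, CIFAR_MEAN)), dtype=np.float64)

# Normalize(CIFAR_MEAN, CIFAR_STD) maps a fill-colored pixel to (x - mean) / std.
normalized = (fill - CIFAR_MEAN) / CIFAR_STD

# The plotting step undoes this with x * std + mean.
recovered = normalized * CIFAR_STD + CIFAR_MEAN

print("normalized fill:", normalized.round(4))
print("recovered fill: ", recovered.round(2))
```

A correctly filled border is therefore within about 0.02 of zero on every channel after normalization, and it plots as the flat grey (125, 122, 113) rather than as noise.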

andrewilyas commented 1 year ago

I think this is also fixed by #184, which has been merged into 1.0.0.