libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0
2.8k stars 180 forks

ffcv transforms ToTensor is not normalizing in range [0, 1] #270

Closed SerezD closed 1 year ago

SerezD commented 1 year ago

I have noticed that setting up a pipeline like:

image_pipeline = [CenterCropRGBImageDecoder(image_size, ratio=ratio), ToTensor(), ToTorchImage()]

results in batches of torch.uint8 image tensors in the range [0, 255].
I cannot find any way to get batches of images as torch.float32 in the range [0, 1], as torchvision.transforms.ToTensor does (docs here).

As a solution, I have written a custom Transform like the following:

from ffcv.pipeline.operation import Operation
from ffcv.pipeline.allocation_query import AllocationQuery
from ffcv.pipeline.state import State
from dataclasses import replace

from typing import Tuple, Optional, Callable

import numpy as np

class DivideImage255(Operation):

    def generate_code(self) -> Callable:
        def divide(image, dst):
            # No buffer was requested in declare_state_and_memory, so dst
            # is unused here; a fresh float32 array is allocated per call.
            dst = image.astype('float32') / 255.
            return dst

        divide.is_parallel = True

        return divide

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        # Advertise the dtype change to downstream operations;
        # no scratch memory is requested.
        return replace(previous_state, dtype=np.float32), None

In practice, this converts images to float32 and divides by 255.
Is this the only way to obtain the desired result, or am I missing something?
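For what it's worth, the generated divide kernel is plain numpy at its core; stripped of the FFCV Operation plumbing, it behaves like this (a standalone sketch with illustrative toy data, not from the thread):

```python
import numpy as np

# What the divide kernel computes, minus the FFCV plumbing:
# cast uint8 pixels to float32 and rescale from [0, 255] to [0, 1].
def divide(image):
    return image.astype('float32') / 255.

batch = np.array([[0, 127, 255]], dtype=np.uint8)  # toy "image" data
out = divide(batch)
# out is float32, spanning [0.0, 1.0]
```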

Would it be possible to add a ToTensor() transform that behaves like its torch counterpart (merging the ToTensor, ToTorchImage, and DivideImage255 operations) in a future release?

andrewilyas commented 1 year ago

Hi @SerezD ! Check out examples/cifar/train_cifar.py -- basically, we never actually bring the images into [0, 1]; instead, we normalize them directly with scaled normalization parameters (after converting them to floating point with the Convert transform):

            Convert(ch.float16),
            torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
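Concretely, normalizing in the [0, 255] domain with statistics scaled by 255 is equivalent to dividing by 255 first and then normalizing with the usual [0, 1] statistics, since (x - 255·m) / (255·s) = (x/255 - m) / s. A minimal numpy sketch (the [0, 1] CIFAR-10 statistics below are commonly cited values, not taken from this thread):

```python
import numpy as np

# Commonly cited CIFAR-10 channel statistics in the [0, 1] convention
# (illustrative values, not from this thread).
mean_01 = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
std_01 = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)

# Scale them by 255 to normalize images that are still in [0, 255].
CIFAR_MEAN = mean_01 * 255
CIFAR_STD = std_01 * 255

x = np.array([0.0, 127.0, 255.0], dtype=np.float32)  # one pixel per channel

a = (x - CIFAR_MEAN) / CIFAR_STD    # normalize directly in [0, 255]
b = (x / 255.0 - mean_01) / std_01  # divide to [0, 1] first, then normalize
```

Both routes produce the same normalized values, which is why FFCV's pipelines skip the division entirely.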

Feel free to re-open the issue if this doesn't help!