pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

BUG: ColorJitter in torchvision.transforms #2563

Closed CristianManta closed 4 years ago

CristianManta commented 4 years ago

🐛 Bug

ColorJitter is documented as accepting either a PIL Image or a Tensor, but in practice it only works on PIL Images.

To Reproduce

Steps to reproduce the behavior:

  1. Load the dataset with a transform that composes ToTensor() followed by ColorJitter().
  2. Create a DataLoader from that dataset.
  3. Iterate over the loader.

Code example that reproduces this bug:

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
root = 'path/to/cifar/data'

color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
ds_train = CIFAR10(root, download=True, train=True, transform=transforms.Compose([transforms.ToTensor(), color_jitter]))

train_loader = DataLoader(ds_train, batch_size=128, num_workers=4, drop_last=True, shuffle=False)

for (images, label) in train_loader:
    print("There is no Bug!")

Error message:

Traceback (most recent call last):
  File "train.py", line 11, in <module>
    for (images, label) in train_loader:
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 400, in __next__
    data = self._next_data()
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1032, in _next_data
    return self._process_data(data)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1058, in _process_data
    data.reraise()
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/_utils.py", line 420, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/datasets/cifar.py", line 120, in __getitem__
    img = self.transform(img)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 54, in __call__
    img = t(img)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 1092, in forward
    img = F.adjust_hue(img, hue_factor)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 728, in adjust_hue
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Expected behavior

The loop should run without errors, since ColorJitter is documented to accept Tensor inputs.

Environment

PyTorch version: 1.7.0.dev20200807
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti

Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.7.0.dev20200807
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchvision==0.8.0.dev20200807
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2020.1                      217  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.1.0            py37h23d657b_0  
[conda] mkl_random                1.1.1            py37h0573a6f_0  
[conda] numpy                     1.19.1           py37hbc911f0_0  
[conda] numpy-base                1.19.1           py37hfa32c7d_0  
[conda] pytorch                   1.7.0.dev20200807 py3.7_cuda10.2.89_cudnn7.6.5_0    pytorch-nightly
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchnet                  0.0.4                    pypi_0    pypi
[conda] torchvision               0.8.0.dev20200807      py37_cu102    pytorch-nightly

Additional context

In transforms/transforms.py on the master branch of this repo, the docstring of forward at line 1064 (in class ColorJitter) says:

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.
        Returns:
            PIL Image or Tensor: Color jittered image.
        """

This contradicts the official documentation at https://pytorch.org/docs/stable/torchvision/transforms.html, which lists ColorJitter under the "Transforms on PIL Image" category. forward then calls, on lines 1077, 1082, 1087 and 1092, the following functions in this order: F.adjust_brightness(img, brightness_factor), F.adjust_contrast(img, contrast_factor), F.adjust_saturation(img, saturation_factor) and F.adjust_hue(img, hue_factor).

In transforms/functional.py, the first three functions handle both the Tensor and the PIL case. For example, adjust_brightness has the following at line 675:

    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)

However, the last one does not appear to have been updated in the same way. At line 742, adjust_hue has:

    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
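To make the difference between the two snippets above concrete, here is a minimal sketch of the dispatch pattern. It does not use torchvision: FakePILImage, FakeTensor, pil_adjust_hue and tensor_adjust_hue are hypothetical stand-ins for the real image types and for F_pil.adjust_hue / F_t.adjust_hue, and adjust_hue_unified only illustrates what a unified dispatch might look like, not the actual fix.

```python
# Hypothetical stand-ins for the two image types; these are NOT torchvision classes.
class FakePILImage:
    def __init__(self, pixels):
        self.pixels = pixels


class FakeTensor:
    def __init__(self, values):
        self.values = values


def pil_adjust_hue(img, hue_factor):
    # Placeholder for the PIL backend (F_pil.adjust_hue in the issue).
    return ("pil", hue_factor)


def tensor_adjust_hue(img, hue_factor):
    # Placeholder for the Tensor backend (F_t.adjust_hue in the issue).
    return ("tensor", hue_factor)


def adjust_hue_current(img, hue_factor):
    # Mirrors the behavior quoted above: only the non-Tensor path is wired up.
    if not isinstance(img, FakeTensor):
        return pil_adjust_hue(img, hue_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_hue_unified(img, hue_factor):
    # Same dispatch shape as adjust_brightness: fall through to the Tensor backend.
    if not isinstance(img, FakeTensor):
        return pil_adjust_hue(img, hue_factor)
    return tensor_adjust_hue(img, hue_factor)
```

With these stand-ins, adjust_hue_current raises TypeError for a FakeTensor input while adjust_hue_unified routes it to the Tensor placeholder, which is the behavior the ColorJitter docstring implies.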

It would also be nice if other transforms, such as RandomGrayscale, accepted both Tensor and PIL inputs and returned an output of the same type as the input.

Pull Request

I opened a pull request (#2566) to attempt to address the reported bug.

vfdev-5 commented 4 years ago

@CristianManta yes, this is known, and the progress on unifying the inputs of transforms can be followed in #2292. F.adjust_hue on Tensor is not yet supported, even though it is already implemented in F_t.adjust_hue. We have to make sure that the output is almost the same as for PIL, so a few tests are necessary before merging it.
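As a rough illustration of what "almost the same as for PIL" could mean in a test, the sketch below compares a float code path with an integer code path for a hue shift and checks the largest per-channel difference against a tolerance. It is a toy built on the standard-library colorsys module, not torchvision's or PIL's actual algorithm; the function names, the sample pixels and the tolerance of one intensity level are all assumptions made for the example.

```python
import colorsys


def shift_hue_float(pixel, hue_factor):
    """Shift the hue of one RGB pixel with channels in [0.0, 1.0]."""
    r, g, b = pixel
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    h = (h + hue_factor) % 1.0  # hue wraps around the color wheel
    return colorsys.hsv_to_rgb(h, s, v)


def shift_hue_uint8(pixel, hue_factor):
    """Same operation on 0..255 integer channels, rounding the result."""
    r, g, b = (c / 255.0 for c in pixel)
    out = shift_hue_float((r, g, b), hue_factor)
    return tuple(int(round(c * 255.0)) for c in out)


def max_abs_diff(float_pixels, uint8_pixels):
    """Largest per-channel difference, measured on the 0..255 scale."""
    return max(
        abs(f * 255.0 - u)
        for fp, up in zip(float_pixels, uint8_pixels)
        for f, u in zip(fp, up)
    )


# Hypothetical sample pixels, chosen only for illustration.
pixels_uint8 = [(255, 0, 0), (12, 200, 90), (128, 128, 128), (0, 0, 0)]
pixels_float = [tuple(c / 255.0 for c in p) for p in pixels_uint8]

out_float = [shift_hue_float(p, 0.2) for p in pixels_float]
out_uint8 = [shift_hue_uint8(p, 0.2) for p in pixels_uint8]

# Rounding to integers moves each channel by at most half a level here,
# so a tolerance of one level is a loose bound for this toy setup.
assert max_abs_diff(out_float, out_uint8) <= 1.0
```

A real test would presumably run the PIL and Tensor backends on the same image and pick the tolerance from the rounding behavior of each backend.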

CristianManta commented 4 years ago

Thanks for mentioning #2292, I completely missed it.