libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Unsupported image type #232

Closed sa-cloud closed 1 year ago

sa-cloud commented 1 year ago

I tried to convert the torchvision dataset datasets.Places365 to ffcv format and got an exception:

Traceback (most recent call last):
  File "/u/jlerner/.conda/envs/ffcv/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/u/jlerner/.conda/envs/ffcv/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/u/jlerner/.conda/envs/ffcv/lib/python3.8/site-packages/ffcv/writer.py", line 112, in worker_job_indexed_dataset
    handle_sample(sample, dest_ix, field_names, metadata, allocator, fields)
  File "/u/jlerner/.conda/envs/ffcv/lib/python3.8/site-packages/ffcv/writer.py", line 50, in handle_sample
    field.encode(destination, field_value, allocator.malloc)
  File "/u/jlerner/.conda/envs/ffcv/lib/python3.8/site-packages/ffcv/fields/rgb_image.py", line 325, in encode
    raise TypeError(f"Unsupported image type {type(image)}")
TypeError: Unsupported image type <class 'torch.Tensor'>

I ran the following lines:

from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
path = '/places365_small/ffcv' 
writer = DatasetWriter(path, {'image': RGBImageField(),'label': IntField()})

from torchvision import transforms, datasets
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

ds = datasets.Places365(root='/places365_small', small=True, split='train-standard', download=False, transform=transform_train)
writer.from_indexed_dataset(ds)
sa-cloud commented 1 year ago

It works when I don't pass in the transform.

briteroses commented 1 year ago

Sorry for seeing this late! The RGBImageField you want to use in the ffcv writer requires a PIL Image to be passed in, not a torch tensor. Your original dataset keeps the images in PIL form, but your torchvision pipeline includes a conversion to tensor, which is why the write only works without the transform. For your task, try shortening the torchvision transform to include only RandomResizedCrop and RandomHorizontalFlip. If you still need to normalize the image, you can do that at load time by including the NormalizeImage operation in your image pipeline:

https://docs.ffcv.io/_modules/ffcv/transforms/normalize.html#NormalizeImage
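
For reference, here is a rough sketch of the write step under that suggestion: the torchvision transform keeps only the PIL-to-PIL operations, so RGBImageField receives PIL Images, and normalization is deferred to loading. The .beton output path below is a placeholder assumption.

from torchvision import transforms, datasets
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

# Only PIL-to-PIL transforms here; no ToTensor/Normalize at write time.
write_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
])

ds = datasets.Places365(root='/places365_small', small=True,
                        split='train-standard', download=False,
                        transform=write_transform)

# Hypothetical output path for the .beton file.
writer = DatasetWriter('/places365_small/ffcv/train.beton',
                       {'image': RGBImageField(), 'label': IntField()})
writer.from_indexed_dataset(ds)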

Furthermore, if you want to do all of the torchvision transforms in ffcv instead, RandomResizedCrop and RandomHorizontalFlip are actually both available in ffcv:

https://docs.ffcv.io/_modules/ffcv/transforms/random_resized_crop.html#RandomResizedCrop

https://docs.ffcv.io/_modules/ffcv/transforms/flip.html#RandomHorizontalFlip
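
A rough sketch of what such a loading pipeline could look like, following the pattern in the ffcv ImageNet example: here the random resized crop is folded into the image decoder (RandomResizedCropRGBImageDecoder) rather than applied as a separate ffcv.transforms.RandomResizedCrop stage, and NormalizeImage works in 0-255 pixel space, so the usual ImageNet stats are scaled accordingly. The .beton path, batch size, and worker count are placeholder assumptions.

import numpy as np
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import RandomResizedCropRGBImageDecoder, IntDecoder
from ffcv.transforms import RandomHorizontalFlip, ToTensor, ToTorchImage, NormalizeImage, Squeeze

# Torchvision's normalization stats, scaled to the 0-255 range NormalizeImage expects.
MEAN = np.array([0.485, 0.456, 0.406]) * 255
STD = np.array([0.229, 0.224, 0.225]) * 255

image_pipeline = [
    RandomResizedCropRGBImageDecoder((224, 224)),  # decode + random resized crop
    RandomHorizontalFlip(),
    ToTensor(),
    ToTorchImage(),
    NormalizeImage(MEAN, STD, np.float32),
]
label_pipeline = [IntDecoder(), ToTensor(), Squeeze()]

loader = Loader('/places365_small/ffcv/train.beton',  # hypothetical .beton path
                batch_size=256, num_workers=8, order=OrderOption.RANDOM,
                pipelines={'image': image_pipeline, 'label': label_pipeline})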