pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Batch support for Transform? #157

Closed arunmallya closed 7 years ago

arunmallya commented 7 years ago

Any plans for updating Transform to support batch inputs instead of just single images? This is useful for applying transforms outside of a DataLoader (which does it on one image at a time).

fmassa commented 7 years ago

I don't think there are any plans on extending transforms to work on batched images. Indeed, I think transforms are supposed to be applied only in Datasets, so only single instances are required. Another point is that implementing batched transforms efficiently would require dedicated implementations, and it would also raise the question of whether or not it would be interesting to have them on GPUs as well.

alykhantejani commented 7 years ago

Closing this for now as there currently are no plans to extend transforms to work on batched images.

Coolnesss commented 6 years ago

Just to follow up on this: right now, to apply a transformation after getting a batch from a DataLoader, I have to iterate over the batch, transform each tensor back to a PIL image, apply any additional transformations, and convert it back to a tensor again. It's doable, but fairly slow (unless I'm doing something wrong).
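For concreteness, a minimal sketch of the per-item round trip described above (the RandomHorizontalFlip is only an illustrative stand-in for the additional transformations, and transform_batch is a hypothetical helper name):

import torch
from torchvision import transforms

to_pil = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
extra = transforms.RandomHorizontalFlip()  # stand-in for the additional transforms

def transform_batch(batch):
    # batch: (N, C, H, W) tensor; each image goes tensor -> PIL -> transform -> tensor
    return torch.stack([to_tensor(extra(to_pil(img))) for img in batch])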

If you're open to a PR on this, I'd be happy to help if you can give me some pointers.

alykhantejani commented 6 years ago

@Coolnesss usually you do the transformations at the Dataset level. The DataLoader uses multiple worker processes that read from the Dataset, which effectively runs your transformations in parallel.

Perhaps you can share some details of what your goal is and we can see if it falls outside of the current paradigm.
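For reference, a minimal sketch of that usual pattern, with the Dataset applying the transform and the DataLoader workers running it in parallel (MNIST and the specific transforms here are only an example):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# The transform runs inside the Dataset's __getitem__, so each DataLoader
# worker applies it to its own samples in parallel.
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)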

Coolnesss commented 6 years ago

Thank you for your reply @alykhantejani !

I'm trying to create derivative datasets of, e.g., MNIST by applying some category of random transformations to each set. Currently, I'm doing something like

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data',
                             train=train,
                             download=True,
                             transform=d_transforms[i])
    loaders.append(
        DataLoader(dataset,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=1)
    )

Here, I get the desired outcome of having multiple DataLoaders that each provide samples from the transformed datasets. However, this is really slow, presumably because each worker tries to access the same files stored in ./data, which can only be read by one worker at a time (?). After profiling my program, nearly all of the time is spent on calls like

x, y = next(iter(train_loaders[i]))

I can think of two ways to solve this:

  1. Apply transformations after getting the batch from the loader - but this requires batched transformations, otherwise it's slow
  2. Make n copies of MNIST on disk and let the workers each have their own copy, e.g. dataset = datasets.MNIST('./data1', ...) etc.

Sorry for the lengthy post, and thanks for your help.

alykhantejani commented 6 years ago

@Coolnesss would this work for you:

from torch.utils.data import Dataset


class MultiTransformDataset(Dataset):
    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, idx):
        input, target = self.dataset[idx]
        return tuple(t(input) for t in self.transforms) + (target,)

    def __len__(self):
        # delegate the length to the wrapped dataset so a DataLoader can sample it
        return len(self.dataset)
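A possible way to use that wrapper together with the earlier d_transforms (only a sketch; wrapping each transform in a Compose with ToTensor is an assumption so that the default collate can stack the outputs):

base = datasets.MNIST('./data', train=train, download=True)  # returns PIL images
per_view = [transforms.Compose([t, transforms.ToTensor()]) for t in d_transforms]
multi = MultiTransformDataset(base, per_view)
loader = DataLoader(multi, shuffle=True, pin_memory=True, num_workers=1)

*views, y = next(iter(loader))  # one batched tensor per transform, plus the targets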

Coolnesss commented 6 years ago

Thanks for the workaround @alykhantejani

It's a much nicer solution, and somewhat faster too. Unfortunately it's still not as fast as I had hoped, perhaps the transforms themselves just take too much time. In any case, thanks for your help.

alykhantejani commented 6 years ago

@Coolnesss np. Let me know if you have any other questions

bermanmaxim commented 6 years ago

Note that you can also design a custom collate function that does the necessary transformations on your batch after collating it, e.g.

import torch


def get_collate(batch_transform=None):
    def mycollate(batch):
        # Collate the samples as usual, then apply the batch-level transform.
        collated = torch.utils.data.dataloader.default_collate(batch)
        if batch_transform is not None:
            collated = batch_transform(collated)
        return collated
    return mycollate

I find this strategy useful for adding information to the batch (such as batch statistics, or complementary images from the dataset), and for making the workers do the necessary computation.
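A hedged usage sketch of that idea (batch_flip is only an illustrative batch-level transform acting on the already-collated tensors, and dataset is assumed to exist):

from torch.utils.data import DataLoader

def batch_flip(collated):
    x, y = collated
    return x.flip(-1), y  # flip every image in the batch along the width

loader = DataLoader(dataset, batch_size=64, num_workers=4,
                    collate_fn=get_collate(batch_flip))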

hukkai commented 4 years ago

Hello, I am working on video tasks where each video is 32 frames of images. I currently need to resize and crop the 32 images in a loop; a batch operation might be helpful (and faster?).
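For what it's worth, recent torchvision versions (roughly 0.8 and later) let the tensor-backed transforms accept leading batch/frame dimensions, so a whole clip can be processed in one call. A minimal sketch, with shapes and sizes chosen only as an example:

import torch
from torchvision import transforms

clip = torch.rand(32, 3, 240, 320)  # 32 frames as a (T, C, H, W) tensor
op = transforms.Compose([
    transforms.Resize(256),          # resizes all frames in one call
    transforms.CenterCrop(224),
])
out = op(clip)                       # (32, 3, 224, 224)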

GabPrato commented 4 years ago

If this can help anyone, I implemented a few batch Transforms: https://github.com/pratogab/batch-transforms

shivam13juna commented 1 year ago

Come on, let's have an official implementation of batch transforms, it's 2023 already!!

AnthonyArmour commented 1 year ago

This would be great for online batch inference. I'm currently looking for a solution to this use case.

chenzhike110 commented 1 year ago

torchvision.transforms.Lambda may help
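For example (only a sketch; the horizontal flip is just a stand-in for any function that already operates on a whole batch):

import torch
from torchvision import transforms

batch_transform = transforms.Lambda(lambda x: x.flip(-1))  # works on (N, C, H, W)
images = torch.rand(8, 3, 32, 32)
flipped = batch_transform(images)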