pystiche / papers

Reference implementation and replication of prominent NST papers
BSD 3-Clause "New" or "Revised" License

Skipping DataLoader in ulyanov_et_al_2016 #286

Closed jbueltemeier closed 2 years ago

jbueltemeier commented 2 years ago

A current problem with the training in ulyanov_et_al_2016 is that it raises an error if the images in the dataset are too small. The error is caused by trying to crop to an image size larger than the image itself. As mentioned in #248, the authors of the original publication solved this by simply using a different image (see L91-L100).
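
For context, this is roughly the kind of failure, illustrated here with torchvision's RandomCrop rather than the actual transform used in ulyanov_et_al_2016:

import torch
from torchvision import transforms

image = torch.rand(3, 120, 90)  # smaller than a typical 256 px training crop
try:
    transforms.RandomCrop(256)(image)
except ValueError as error:
    print(error)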

This PR solves the problem with a SkipDataLoader that falls back to the next image in case of an error. This is a minimal solution, which should be enough for this replication. Or do you have an idea for how this could be done better, @pmeier?

As a side note, this affects maybe 100-200 images out of the nearly 40k images in the MSCOCO dataset.
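
For illustration only, the "use the next image in case of an error" idea can be pictured as a thin dataset wrapper along these lines (hypothetical names, not the actual SkipDataLoader code of this PR):

import torch.utils.data

class SkipOnErrorDataset(torch.utils.data.Dataset):
    # hypothetical sketch: retry with the next index whenever loading or
    # transforming the current sample fails
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        for offset in range(len(self.dataset)):
            try:
                return self.dataset[(idx + offset) % len(self.dataset)]
            except RuntimeError:
                continue
        raise RuntimeError("no usable sample found in the dataset")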

pmeier commented 2 years ago

What about adding a

import torch.utils.data

class SkipSmallIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, min_size):
        self.dataset = dataset
        self.min_size = min_size

    def __iter__(self):
        for sample in self.dataset:
            if sample >= self.min_size:
                yield sample

dataset = SkipSmallIterableDataset([6, 5, 4, 3, 2, 1, 4], min_size=4)

for sample in torch.utils.data.DataLoader(dataset, batch_size=1):
    print(sample)
tensor([6])
tensor([5])
tensor([4])
tensor([4])
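
The integers stand in for images here; with real image tensors the comparison would presumably become a spatial size check, roughly:

def is_large_enough(image, min_size):
    # image is assumed to be a CxHxW tensor; compare its smaller spatial edge
    return min(image.shape[-2:]) >= min_size
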
jbueltemeier commented 2 years ago

I have also thought about this, but I have not found a simple solution in combination with a BatchSampler that yields the next index in case of a small image.

In addition, in at least one case the dataset must be iterated over more than once (50k iterations).

pmeier commented 2 years ago

Could you compile a minimal example with all the parts that we need so we have a clearer picture? If my implementation does not work, I'm happy to adapt yours. It is complicated, though, which is why I'm looking for easier alternatives.

jbueltemeier commented 2 years ago

I am aware that my solution is not optimal. I just don't know how to change the __getitem__ method and the idx.

Here is a minimal example where a RuntimeError is currently raised if the input is too small:

import torch.utils.data
from pystiche.image import transforms
from pystiche_papers.data.utils import FiniteCycleBatchSampler

class SmallImageError(transforms.Transform):
    def __init__(self, min_size):
        super().__init__()
        self.min_size = min_size

    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        if input_image >= self.min_size:
            return input_image
        msg = "Small image error!"
        raise RuntimeError(msg)

class SkipSmallIterableDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, min_size, transform):
        self.dataset = dataset
        self.min_size = min_size
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        data = self.dataset[idx]
        if self.transform:
            data = self.transform(data)
        return data

min_size = 2
transform = transforms.ComposedTransform(SmallImageError(min_size=min_size))
dataset = SkipSmallIterableDataset([6, 5, 4, 3, 2, 1, 4], min_size=min_size, transform=transform)
batch_sampler = FiniteCycleBatchSampler(dataset, num_batches=10, batch_size=1)
for sample in torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler):
    print(sample)

Current output (with the error message shortened):

tensor([6])
tensor([5])
tensor([4])
tensor([3])
tensor([2])
RuntimeError: Small image error!

The desired output:

tensor([6])
tensor([5])
tensor([4])
tensor([3])
tensor([2])
tensor([4])
tensor([6])
tensor([5])
tensor([4])

The SkipDataLoader in this PR produces this desired output.

pmeier commented 2 years ago

Skipping is only possible with iterable-style datasets. Unfortunately, those don't work with samplers or batch samplers (a small demo of this follows after the example below). Thus, I think it would be easiest to implement a small custom dataset that integrates the functionality we currently have in FiniteCycleBatchSampler:

import itertools

import torch.utils.data
from torch import nn

class SmallImageError(nn.Module):
    def __init__(self, min_size):
        super().__init__()
        self.min_size = min_size

    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        if input_image >= self.min_size:
            return input_image
        msg = "Small image error!"
        raise RuntimeError(msg)

class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, min_size, transform, num_samples):
        self.dataset = dataset
        self.min_size = min_size
        self.transform = transform
        self.num_samples = num_samples

    def __iter__(self):
        dataset = itertools.cycle(self.dataset)
        num_samples = 0
        while num_samples < self.num_samples:
            sample = next(dataset)
            if sample >= self.min_size:
                yield self.transform(sample)
                num_samples += 1

min_size = 2
transform = SmallImageError(min_size=min_size)
dataset = Dataset(
    [6, 5, 4, 3, 2, 1, 4],
    min_size=min_size,
    transform=transform,
    # num_batches * batch_size
    num_samples=10 * 1,
)

for idx, sample in enumerate(torch.utils.data.DataLoader(dataset), 1):
    print(idx, sample)
1 tensor([6])
2 tensor([5])
3 tensor([4])
4 tensor([3])
5 tensor([2])
6 tensor([4])
7 tensor([6])
8 tensor([5])
9 tensor([4])
10 tensor([3])
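
For reference, the incompatibility mentioned above shows up directly at construction time; the DataLoader rejects sampler options for iterable-style datasets (a small demo with made-up names):

import torch.utils.data

class Numbers(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(range(5))

try:
    # sampler and batch_sampler are only valid for map-style datasets
    torch.utils.data.DataLoader(Numbers(), batch_sampler=[[0], [1]])
except ValueError as error:
    print(error)
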
codecov-commenter commented 2 years ago

Codecov Report

Merging #286 (50afab4) into main (0d8179d) will decrease coverage by 0.5%. The diff coverage is 67.8%.

@@           Coverage Diff           @@
##            main    #286     +/-   ##
=======================================
- Coverage   97.0%   96.5%   -0.6%     
=======================================
  Files         39      39             
  Lines       1639    1655     +16     
=======================================
+ Hits        1591    1598      +7     
- Misses        48      57      +9     
Flag | Coverage Δ
unit | 96.5% <67.8%> (-0.6%) ↓


Impacted Files | Coverage Δ
pystiche_papers/ulyanov_et_al_2016/_utils.py | 100.0% <ø> (ø)
pystiche_papers/ulyanov_et_al_2016/_data.py | 88.2% <67.8%> (-8.3%) ↓

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.

pmeier commented 2 years ago

Oops, it seems you pushed while I was reviewing. You can probably ignore some of my comments.

pmeier commented 2 years ago

I learned today that IterableDataset needs special handling if it is used with more than one worker: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset. Although we set num_workers=0 by default, we cannot guarantee that users will do the same. Do you want me to fix this here, or should we open an issue and fix it later?
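
For reference, the special handling from the linked docs boils down to splitting the work between workers via torch.utils.data.get_worker_info(); a rough sketch (not this repository's code):

import itertools

import torch.utils.data

class WorkerAwareDataset(torch.utils.data.IterableDataset):
    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading, i.e. num_workers=0
            return iter(self.samples)
        # each worker yields only every num_workers-th sample so the
        # DataLoader does not produce duplicated batches
        return itertools.islice(
            iter(self.samples), worker_info.id, None, worker_info.num_workers
        )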

jbueltemeier commented 2 years ago

That's interesting, thanks for the comment. Since this PR is only about this dataset, I think it makes sense to integrate the fix here, so feel free to do it in this PR.

jbueltemeier commented 2 years ago

I also learned something today (see this Issue):

itertools.cycle attempts to save all outputs in order to re-cycle through them

This increases memory usage with each iteration until the script crashes. Currently, however, itertools.cycle is required so that the Dataset can cycle through the images several times.

One solution would be the approach from #77:

        def cycle(data_io):
            while True:
                for x in data_io:
                    yield x
        self.data_samples = iter(cycle(self.dataset)) 
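
Integrated into the Dataset sketched above, that could look roughly like this (an untested sketch):

import torch.utils.data

class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, min_size, transform, num_samples):
        self.dataset = dataset
        self.min_size = min_size
        self.transform = transform
        self.num_samples = num_samples

    def _cycle(self):
        # re-iterate the underlying dataset instead of caching every sample
        # the way itertools.cycle does
        while True:
            for sample in self.dataset:
                yield sample

    def __iter__(self):
        samples = self._cycle()
        num_samples = 0
        while num_samples < self.num_samples:
            sample = next(samples)
            if sample >= self.min_size:
                yield self.transform(sample)
                num_samples += 1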

Or do you see a better solution @pmeier?

pmeier commented 2 years ago

Thanks a lot @jbueltemeier!