Closed jbueltemeier closed 2 years ago
What about adding a wrapper dataset like the following?
import torch.utils.data

class SkipSmallIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, min_size):
        self.dataset = dataset
        self.min_size = min_size

    def __iter__(self):
        for sample in self.dataset:
            if sample >= self.min_size:
                yield sample

dataset = SkipSmallIterableDataset([6, 5, 4, 3, 2, 1, 4], min_size=4)
for sample in torch.utils.data.DataLoader(dataset, batch_size=1):
    print(sample)
tensor([6])
tensor([5])
tensor([4])
tensor([4])
I have also thought about this, but have not found a simple solution in combination with a BatchSampler, which in the case of a small image yields the next index.
In addition, in at least one case, the dataset must be iterated over more than once (50k iterations).
Could you compile a minimal example with all the parts that we need, so we have a clearer picture? If my implementation does not work, I'm happy to adapt yours. It is complicated though. Hence, I'm looking for easier alternatives.
I am aware that my solution is not optimal. I just don't know how to change the __getitem__ method and the idx.
Here is a minimal example where currently a RuntimeError is raised if the input is too small:
import torch.utils.data
from pystiche.image import transforms
from pystiche_papers.data.utils import FiniteCycleBatchSampler

class SmallImageError(transforms.Transform):
    def __init__(self, min_size):
        super().__init__()
        self.min_size = min_size

    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        if input_image >= self.min_size:
            return input_image
        msg = "Small image error!"
        raise RuntimeError(msg)

class SkipSmallIterableDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, min_size, transform):
        self.dataset = dataset
        self.min_size = min_size
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        data = self.dataset[idx]
        if self.transform:
            data = self.transform(data)
        return data

min_size = 2
transform = transforms.ComposedTransform(SmallImageError(min_size=min_size))
dataset = SkipSmallIterableDataset([6, 5, 4, 3, 2, 1, 4], min_size=min_size, transform=transform)
batch_sampler = FiniteCycleBatchSampler(dataset, num_batches=10, batch_size=1)
for sample in torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler):
    print(sample)
Current output (with a shortened error message):
tensor([6])
tensor([5])
tensor([4])
tensor([3])
tensor([2])
RuntimeError: Small image error!
The output that is needed:
tensor([6])
tensor([5])
tensor([4])
tensor([3])
tensor([2])
tensor([4])
tensor([6])
tensor([5])
tensor([4])
The needed output is achieved with the SkipDataLoader in this PR.
Skipping is only possible with iterable datasets. Unfortunately, they don't work with samplers or batch samplers. Thus, I think it would be easiest to implement a small custom dataset that integrates the functionality we have in FiniteCycleBatchSampler:
import itertools
import torch.utils.data
from torch import nn

class SmallImageError(nn.Module):
    def __init__(self, min_size):
        super().__init__()
        self.min_size = min_size

    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        if input_image >= self.min_size:
            return input_image
        msg = "Small image error!"
        raise RuntimeError(msg)

class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, min_size, transform, num_samples):
        self.dataset = dataset
        self.min_size = min_size
        self.transform = transform
        self.num_samples = num_samples

    def __iter__(self):
        dataset = itertools.cycle(self.dataset)
        num_samples = 0
        while num_samples < self.num_samples:
            sample = next(dataset)
            # Only samples that are large enough count towards num_samples.
            if sample >= self.min_size:
                yield self.transform(sample)
                num_samples += 1

min_size = 2
transform = SmallImageError(min_size=min_size)
dataset = Dataset(
    [6, 5, 4, 3, 2, 1, 4],
    min_size=min_size,
    transform=transform,
    # num_batches * batch_size
    num_samples=10 * 1,
)
for idx, sample in enumerate(torch.utils.data.DataLoader(dataset), 1):
    print(idx, sample)
1 tensor([6])
2 tensor([5])
3 tensor([4])
4 tensor([3])
5 tensor([2])
6 tensor([4])
7 tensor([6])
8 tensor([5])
9 tensor([4])
10 tensor([3])
Merging #286 (50afab4) into main (0d8179d) will decrease coverage by 0.5%. The diff coverage is 67.8%.
@@ Coverage Diff @@
## main #286 +/- ##
=======================================
- Coverage 97.0% 96.5% -0.6%
=======================================
Files 39 39
Lines 1639 1655 +16
=======================================
+ Hits 1591 1598 +7
- Misses 48 57 +9
Flag | Coverage Δ | |
---|---|---|
unit | 96.5% <67.8%> (-0.6%) | :arrow_down: |
Impacted Files | Coverage Δ | |
---|---|---|
pystiche_papers/ulyanov_et_al_2016/_utils.py | 100.0% <ø> (ø) | |
pystiche_papers/ulyanov_et_al_2016/_data.py | 88.2% <67.8%> (-8.3%) | :arrow_down: |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Oops, it seems you pushed while I was reviewing. You can probably ignore some of my comments.
I learned today that IterableDataset needs special handling if used with more than one worker: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset. Although we set num_workers=0 by default, we cannot guarantee that users will do the same. Do you want me to fix this here, or should we add an issue and fix it later?
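For reference, a minimal sketch of the kind of handling the linked docs describe, assuming each worker should simply take every num_workers-th sample. ShardedIterableDataset is a hypothetical name for illustration and is not part of this PR:

```python
import itertools

import torch.utils.data


class ShardedIterableDataset(torch.utils.data.IterableDataset):
    # Hypothetical wrapper: splits the samples across DataLoader workers
    # so that no sample is yielded more than once per pass.
    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading (num_workers=0): yield everything.
            return iter(self.samples)
        # Worker `info.id` takes every `info.num_workers`-th sample.
        return itertools.islice(self.samples, info.id, None, info.num_workers)


dataset = ShardedIterableDataset([6, 5, 4, 3, 2, 1, 4])
loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)
samples = [batch.item() for batch in loader]
```

Without such a split, every worker would iterate over its own full copy of the dataset and each sample would show up once per worker. A fixed num_samples budget, as in the Dataset above, would also have to be divided between the workers.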
That's interesting, thanks for the comment. Since this PR is only for this dataset, I think it makes sense to integrate it here. So feel free to do it here.
I also learned something today (see this Issue): itertools.cycle attempts to save all outputs in order to re-cycle through them. This increases the memory usage with each iteration until the script crashes. Currently, however, itertools.cycle is required for the Dataset to cycle through the images several times.
One solution would be from #77:
def cycle(data_io):
    while True:
        for x in data_io:
            yield x

self.data_samples = iter(cycle(self.dataset))
Or do you see a better solution @pmeier?
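To illustrate the difference, here is a small sketch that counts how often the source is iterated. CountingSource is a hypothetical stand-in for the dataset: itertools.cycle reads it once and replays its own saved copy, while the generator variant re-reads the source on every pass and keeps nothing.

```python
import itertools


def cycle(iterable):
    # Re-iterate the source on every pass instead of caching its items.
    while True:
        for item in iterable:
            yield item


class CountingSource:
    # Hypothetical stand-in for a dataset; counts how often it is iterated.
    def __init__(self, items):
        self.items = items
        self.num_passes = 0

    def __iter__(self):
        self.num_passes += 1
        return iter(self.items)


cached = CountingSource([6, 5, 4])
from_itertools = list(itertools.islice(itertools.cycle(cached), 7))

reread = CountingSource([6, 5, 4])
from_generator = list(itertools.islice(cycle(reread), 7))

print(from_itertools, cached.num_passes)
print(from_generator, reread.num_passes)
```

Both variants yield the same sequence. The generator variant only works if the source can be iterated more than once, which is the case for a dataset object but not for a one-shot iterator, and it loops forever on an empty source.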
Thanks a lot @jbueltemeier!
A current problem with the training in ulyanov_et_al_2016 is that it raises an error if the images in the dataset are too small. The error is caused by trying to crop to a size larger than the image. As mentioned in #248, the authors of the original publication solved this by simply using a different image (see L91-L100).
This PR solves this problem by using a SkipDataLoader that uses the next image in case of an error. This is a minimal solution, which should be enough for this replication. Or do you have an idea how this could be done better, @pmeier?
As a side note, this affects maybe 100-200 images out of the nearly 40k images in the MSCOCO dataset.
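The idea of falling back to the next image can be sketched in plain Python as follows. This is only an illustration under simplifying assumptions and not the SkipDataLoader from this PR: iter_skipping_errors and CropDataset are hypothetical names, and integers stand in for image sizes.

```python
def iter_skipping_errors(dataset, indices):
    # Hypothetical helper: yield dataset[idx] for each index and move on
    # to the next index whenever loading a sample raises a RuntimeError.
    for idx in indices:
        try:
            sample = dataset[idx]
        except RuntimeError:
            continue
        yield sample


class CropDataset:
    # Stand-in dataset: integers play the role of image sizes, and the
    # "crop" fails for anything smaller than min_size.
    def __init__(self, sizes, min_size):
        self.sizes = sizes
        self.min_size = min_size

    def __getitem__(self, idx):
        size = self.sizes[idx]
        if size < self.min_size:
            raise RuntimeError("Small image error!")
        return size


dataset = CropDataset([6, 5, 4, 3, 2, 1, 4], min_size=2)
samples = list(iter_skipping_errors(dataset, range(7)))
print(samples)
```

Catching only RuntimeError keeps unrelated failures, such as an IndexError from a bad index, visible instead of silently dropping samples.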