pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

Please change dcgan to load truncated images. #835

Open cynthia-rempel opened 3 years ago

cynthia-rempel commented 3 years ago

When I was using the dcgan example, it crashed with:

Traceback (most recent call last):
  File "dcgan.py", line 220, in <module>
    for i, data in enumerate(dataloader, 0):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 403, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 137, in __getitem__
    sample = self.loader(path)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 173, in default_loader
    return pil_loader(path)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 156, in pil_loader
    return img.convert('RGB')
  File "/opt/conda/lib/python3.8/site-packages/PIL/Image.py", line 902, in convert
    self.load()
  File "/opt/conda/lib/python3.8/site-packages/PIL/ImageFile.py", line 255, in load
    raise OSError(
OSError: image file is truncated (150 bytes not processed)

So I added:

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

to examples/tree/master/dcgan/main.py, and that fixed it. Thanks!

jefflomax commented 3 years ago

It would also be really helpful if we could log the filename of the truncated image.
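
One way to get the file name, sketched under the assumption that you otherwise keep torchvision's default loading: wrap the loader that ImageFolder uses so that any OSError is logged together with the path before it is re-raised. make_logging_loader, the logger name and the bad_paths list are hypothetical, not part of torchvision.

```python
import logging
from typing import Any, Callable, List, Optional

# Hypothetical logger name for this sketch
logger = logging.getLogger("dcgan.images")


def make_logging_loader(
    load: Callable[[str], Any],
    bad_paths: Optional[List[str]] = None,
) -> Callable[[str], Any]:
    """Wrap an image loader so the path of any file that fails to load is logged."""

    def loader(path: str) -> Any:
        try:
            return load(path)
        except OSError as exc:
            # PIL's "image file is truncated" error does not name the file,
            # so log the path here, optionally remember it, then re-raise
            logger.warning("Failed to load image %s: %s", path, exc)
            if bad_paths is not None:
                bad_paths.append(path)
            raise

    return loader
```

To use it, pass loader=make_logging_loader(default_loader) to dset.ImageFolder, where default_loader comes from torchvision.datasets.folder. This only reports files while ImageFile.LOAD_TRUNCATED_IMAGES is left at False, because with the flag set PIL loads truncated images without raising. With num_workers > 0 the loader runs inside DataLoader worker processes, so each worker appends to its own copy of bad_paths; the log output is the more reliable record in that case.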

msaroufim commented 2 years ago

@cynthia-rempel would you like to make a PR with this fix?

matttturnbull commented 4 months ago

Agreed, would be very useful!

I tried switching to this dataset: https://huggingface.co/datasets/cats_vs_dogs

@cynthia-rempel's solution above fixes the truncated image error, but I'm hitting a snag during training: at some point values turn to 'nan', and after that (as you can imagine) training is broken.

Has anyone run into something similar, or does anyone have ideas on how to start debugging this?

[Screenshot, 2024-04-22 3:26:50 PM]
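
A common first step for this kind of problem is to stop at the first iteration where a value becomes non-finite, so you know the exact epoch and batch where it happens. The sketch below assumes the losses are available as Python floats (for example via .item()); check_finite is a hypothetical helper, not a PyTorch API.

```python
import math
from typing import Dict


def check_finite(step: int, **values: float) -> None:
    """Raise as soon as any named scalar (e.g. a loss) is NaN or infinite."""
    # Collect every value that is NaN, +inf or -inf
    bad: Dict[str, float] = {
        name: v for name, v in values.items() if not math.isfinite(v)
    }
    if bad:
        raise FloatingPointError(f"non-finite values at step {step}: {bad}")
```

In the dcgan training loop, which names its losses errD and errG, a call such as check_finite(i, errD=errD.item(), errG=errG.item()) after the optimizer steps would raise at the first bad batch. PyTorch's torch.autograd.set_detect_anomaly(True) can additionally point to the backward operation that produced the NaN, at the cost of slower training.
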
matttturnbull commented 4 months ago

For what it's worth, I created an image checker as well (probably logging too much for this example project, but useful for the above debugging):

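from pathlib import Path
from PIL import Image, ImageFile

# Load truncated images instead of raising "OSError: image file is truncated"
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Rules an image must satisfy to be used for training
img_rules = {
    "suffix": ['jpg', 'jpeg', 'png'],  # allowed file extensions (lower case)
    "formats": ['JPEG', 'PNG'],        # allowed formats as detected by PIL
    "size": 1000,                      # maximum width and height in pixels
    "modes": ['RGB']                   # allowed colour modes
}
img_counts = {
    "total": 0,
    "rejected": 0,
}

def is_image(path):
    """Return True if the file at `path` decodes cleanly and satisfies img_rules."""
    img_counts["total"] += 1
    # Check the file extension (case-insensitive, so "photo.JPG" is accepted)
    suffix = Path(path).suffix.lstrip('.').lower()
    if suffix not in img_rules['suffix']:
        print(f"{path} >> ❌ Suffix")
        img_counts["rejected"] += 1
        return False
    try:
        # Image.open() only reads the header, so call load() to decode the
        # pixel data and surface corrupt files; then check each rule
        with Image.open(path) as img:
            img.load()
            reasons = []
            if img.format not in img_rules['formats']:
                reasons.append(f"Format: {img.format}")
            if img.size[0] > img_rules['size'] or img.size[1] > img_rules['size']:
                reasons.append(f"Size: {img.size}")
            if img.mode not in img_rules['modes']:
                reasons.append(f"Mode: {img.mode}")
    except Exception as exc:
        # Unreadable or corrupt file: report and count it instead of failing silently
        reasons = [f"Error: {exc}"]
    if reasons:
        print(f"{path} >> ❌ {', '.join(reasons)}")
        img_counts["rejected"] += 1
        return False
    return True

def print_summary():
    """Print the totals; call this after the dataset has been built."""
    accepted = img_counts['total'] - img_counts['rejected']
    print(f"{img_counts['total']} Images | ❌ Rejected: {img_counts['rejected']} | ✅ Accepted: {accepted}")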

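You can then run the checker above on every image file by passing it as the is_valid_file parameter of datasets.ImageFolder. It is called once per file while the dataset indexes its folder, and files for which it returns False are skipped: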

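import torchvision.datasets as dset
import torchvision.transforms as transforms

# Create the dataset (dataroot and image_size are assumed to be defined earlier);
# is_valid_file makes ImageFolder skip every file for which is_image returns False
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize(
                                   (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]),
                           is_valid_file=is_image)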
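
Because ImageFolder applies is_valid_file only while indexing, rejected images are simply left out of the dataset. To review or clean them up once, outside the training script, a small standalone scan with the same predicate may be handier. scan_images is a hypothetical helper written for this sketch, not a torchvision API.

```python
from pathlib import Path
from typing import Callable, List, Tuple


def scan_images(
    root: str, is_valid: Callable[[str], bool]
) -> Tuple[List[str], List[str]]:
    """Run `is_valid` on every file below `root`; return (accepted, rejected) paths."""
    accepted: List[str] = []
    rejected: List[str] = []
    # Walk the tree in a stable order, skipping directories
    for path in sorted(Path(root).rglob("*")):
        if path.is_file():
            target = accepted if is_valid(str(path)) else rejected
            target.append(str(path))
    return accepted, rejected
```

For example, accepted, rejected = scan_images(dataroot, is_image) returns both lists, so the rejected files can be inspected or moved aside before training starts.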