crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

AttributeError in Oxford Flowers Demo: 'dict' object has no attribute 'convert' #93

Open mnslarcher opened 6 months ago

mnslarcher commented 6 months ago

Hi,

First, congratulations on this amazing repository; it's a great codebase for study.

I'm attempting to run the Oxford Flowers demo but am encountering an error. I'm not sure if it's just me or if the demo is currently non-functional (perhaps due to some changes on Hugging Face's end, or something else I'm unaware of):

    images = [transform(image.convert(mode)) for image in examples[image_key]]
AttributeError: 'dict' object has no attribute 'convert'

Here's the full traceback:

Traceback (most recent call last):
  File "/notebooks/k-diffusion/train.py", line 525, in <module>
    main()
  File "/notebooks/k-diffusion/train.py", line 435, in main
    for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process):
  File "/usr/local/lib/python3.9/dist-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.9/dist-packages/accelerate/data_loader.py", line 451, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 694, in reraise
    raise exception
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2165, in __getitem__
    return self._getitem(
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2150, in _getitem
    formatted_output = format_table(
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 532, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 281, in __call__
    return self.format_row(pa_table)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 387, in format_row
    formatted_batch = self.format_batch(pa_table)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 418, in format_batch
    return self.transform(batch)
  File "/notebooks/k-diffusion/k_diffusion/utils.py", line 39, in hf_datasets_augs_helper
    images = [transform(image.convert(mode)) for image in examples[image_key]]
  File "/notebooks/k-diffusion/k_diffusion/utils.py", line 39, in <listcomp>
    images = [transform(image.convert(mode)) for image in examples[image_key]]
AttributeError: 'dict' object has no attribute 'convert'

mnslarcher commented 6 months ago

OK, I think I found the problem.

from datasets import load_dataset

dataset = load_dataset("nelorth/oxford-flowers")
image = dataset["train"][0]["image"]
print(image.keys())

gives you:

dict_keys(['bytes', 'path'])

This breaks hf_datasets_augs_helper, which expects dataset["train"][0]["image"] to be a PIL Image, not a dict:

def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(image.convert(mode)) for image in examples[image_key]]
    return {image_key: images}

A fix for this dataset would be to decode the raw bytes into a PIL Image before converting:

import io
from PIL import Image

def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(Image.open(io.BytesIO(image["bytes"])).convert(mode)) for image in examples[image_key]]
    return {image_key: images}
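
The fix above only handles the dict form, so it would presumably fail on datasets that still return decoded PIL Images. A minimal sketch of a variant that accepts both is below. The to_pil helper is a hypothetical name introduced here for illustration; it is not part of k-diffusion. The sketch assumes the dict always has the 'bytes' and 'path' keys printed above.

```python
import io

from PIL import Image


def to_pil(image):
    """Hypothetical helper: return a PIL image from a PIL image or a {'bytes', 'path'} dict."""
    if isinstance(image, Image.Image):
        # Already decoded, which is what the original helper expected.
        return image
    if isinstance(image, dict):
        # Undecoded form: prefer the in-memory bytes, fall back to the file path.
        if image.get("bytes") is not None:
            return Image.open(io.BytesIO(image["bytes"]))
        if image.get("path") is not None:
            return Image.open(image["path"])
    raise TypeError(f"Unsupported image type: {type(image)!r}")


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(to_pil(image).convert(mode)) for image in examples[image_key]]
    return {image_key: images}
```

With this version the same call works whether examples[image_key] holds PIL Images or {'bytes', 'path'} dicts, and anything else raises a TypeError instead of the less obvious AttributeError.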