Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

If data samples contain a Python list with a variable number of elements, type inference will fail #260

Open senarvi opened 1 month ago

senarvi commented 1 month ago

🐛 Bug

When optimizing a dataset, BinaryWriter.serialize() first flattens the sample dictionary and infers the type of each element. It then assumes that all data samples have the same types and the same number of elements. If a list contains a variable number of elements, the cached type information will be incorrect for subsequent data samples. This is not documented and can cause confusion.
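
To make this concrete, here is a minimal sketch (plain Python, not litdata's actual pytree code) of why a variable-length list breaks a cached per-element format: every list element becomes its own leaf when the sample dictionary is flattened, so the inferred type list changes length from sample to sample.

def flatten(sample):
    # Mimics the flattening step: each dict value is a leaf, except that
    # Python lists are decomposed into one leaf per element.
    leaves = []
    for value in sample.values():
        if isinstance(value, list):
            leaves.extend(value)
        else:
            leaves.append(value)
    return leaves

sample_a = {"index": 0, "class": [3, 7, 1], "name": "a"}  # three labels
sample_b = {"index": 1, "class": [5], "name": "b"}        # one label

print([type(leaf).__name__ for leaf in flatten(sample_a)])  # ['int', 'int', 'int', 'int', 'str']
print([type(leaf).__name__ for leaf in flatten(sample_b)])  # ['int', 'int', 'str'] -- no longer matches the cached format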

To Reproduce

When storing the data in NumPy arrays, there's no problem, because each array is treated as a single element in the flattened type list. So this works:

import numpy as np
from PIL import Image
import litdata as ld

def random_images(index):
    fake_images = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
    num_labels = np.random.randint(10)
    fake_labels = np.random.randint(10, size=num_labels)

    # use any key:value pairs
    data = {"index": index, "image": fake_images, "class": fake_labels}

    return data

if __name__ == "__main__":
    # the optimize function outputs data in an optimized format (chunked, binarized, etc...)
    ld.optimize(
        fn=random_images,                   # the function applied to each input
        inputs=list(range(1000)),           # the inputs to the function (here it's a list of numbers)
        output_dir="my_optimized_dataset",  # optimized data is stored here
        num_workers=4,                      # The number of workers on the same machine
        chunk_bytes="64MB"                  # size of each chunk
    )
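
For reference, the optimized output can then be read back with litdata's streaming API (a short usage sketch; the directory is the output_dir used above):

import litdata as ld

dataset = ld.StreamingDataset("my_optimized_dataset")
print(len(dataset))  # 1000 samples
sample = dataset[0]  # a dict with the keys "index", "image" and "class"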

When storing the data in a Python list, the type of each list element is inferred separately. If the list length varies, the type information inferred from one sample does not apply to the other samples. If we now add an element with a different type (say, a string) after the variable-length list, the optimization crashes:

def random_images(index):
    fake_images = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
    num_labels = np.random.randint(10)
    fake_labels = np.random.randint(10, size=num_labels).tolist()
    name = "name"

    # use any key:value pairs
    data = {"index": index, "image": fake_images, "class": fake_labels, "name": name}

    return data

The error:

Rank 0 inferred the following `['int', 'pil', 'int', 'int', 'int', 'str']` data format.
Rank 1 inferred the following `['int', 'pil', 'int', 'int', 'int', 'str']` data format.
Worker 1 is done.
Traceback (most recent call last):
  File ".../example.py", line 20, in <module>
    ld.optimize(
  File ".../lib/python3.12/site-packages/litdata/processing/functions.py", line 432, in optimize
    data_processor.run(
  File ".../lib/python3.12/site-packages/litdata/processing/data_processor.py", line 1055, in run
    self._exit_on_error(error)
  File ".../lib/python3.12/site-packages/litdata/processing/data_processor.py", line 1119, in _exit_on_error
    raise RuntimeError(f"We found the following error {error}.")
RuntimeError: We found the following error Traceback (most recent call last):
  File ".../lib/python3.12/site-packages/litdata/processing/data_processor.py", line 665, in _handle_data_chunk_recipe
    chunk_filepath = self.cache._add_item(self._index_counter, item_data_or_generator)
  File ".../lib/python3.12/site-packages/litdata/streaming/cache.py", line 134, in _add_item
    return self._writer.add_item(index, data)
  File ".../lib/python3.12/site-packages/litdata/streaming/writer.py", line 305, in add_item
    data, dim = self.serialize(items)
  File ".../lib/python3.12/site-packages/litdata/streaming/writer.py", line 179, in serialize
    self._serialize_with_data_format(flattened, sizes, data, self._data_format)
  File ".../lib/python3.12/site-packages/litdata/streaming/writer.py", line 210, in _serialize_with_data_format
    serialized_item, _ = serializer.serialize(element)
  File ".../lib/python3.12/site-packages/litdata/streaming/serializers.py", line 347, in serialize
    return obj.encode("utf-8"), None
AttributeError: 'int' object has no attribute 'encode'

Expected behavior

It should be documented that every data sample must contain the same elements and every list must have the same length.

I wonder if caching the data types is necessary. After all, the optimize() call doesn't have to be that fast. If the data types were instead inferred for every sample separately, it would be possible to use variable-length lists.

Would it make sense to at least check in BinaryWriter.serialize() that self._data_format has the same number of elements as the data sample?
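
For illustration, here's a rough sketch of where such a check could live, based on the call shown in the traceback (self._serialize_with_data_format(flattened, sizes, data, self._data_format)). This is only a suggestion, not the current implementation:

def _serialize_with_data_format(self, flattened, sizes, data, data_format):
    # Proposed guard: fail with a clear message instead of a confusing
    # AttributeError when a sample flattens to a different number of
    # elements than the cached data format.
    if len(flattened) != len(data_format):
        raise ValueError(
            f"The sample flattened to {len(flattened)} elements, but the inferred "
            f"data format {data_format} has {len(data_format)} elements. Every sample "
            "must flatten to the same elements; variable-length Python lists are not supported."
        )
    # ... existing serialization loop ...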

Environment

github-actions[bot] commented 1 month ago

Hi! thanks for your contribution!, great first issue!

tchaton commented 1 month ago

Hey @senarvi,

Would you mind making a PR to add a note about this in the README?

Additionally, we could consider adding a placeholder to tell LitData not to decompose the pytree too deeply, like a List or Dict object that would be serialized as a whole.

senarvi commented 1 month ago

Here's a PR with some documentation changes: https://github.com/Lightning-AI/litdata/pull/264

I didn't quite understand how the placeholder would work.

tchaton commented 1 month ago

Hey @senarvi,

I was thinking something like this.

from litdata.types import List, Dict

def fn(...):

    return {"list": List([...])}

This way, the pytree won't be decomposed and the list would be encoded as a single element.

But in reality, any placeholder would do the job, since pytree won't know how to decompose it further.

Example:

class Placeholder:

    def __init__(self, object):
        self.object = object
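
For illustration, a sketch of how the placeholder might be used in the sample function (hypothetical; a matching serializer that encodes self.object, e.g. via pickle, would still need to be registered):

import numpy as np

def random_images(index):
    num_labels = np.random.randint(10)
    fake_labels = np.random.randint(10, size=num_labels).tolist()
    # Wrapping the variable-length list keeps it as a single pytree leaf,
    # so the inferred data format is identical for every sample.
    return {"index": index, "class": Placeholder(fake_labels)}
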
senarvi commented 1 month ago

@tchaton something like that could be an elegant solution. I guess then you need to write a serializer for the Placeholder class. Was your idea that all the elements of List would have a specific data type?

tchaton commented 1 month ago

Yes, we could have something smart for it if we want to. Would you be interested in contributing such a feature?

senarvi commented 1 month ago

That's not something I can take on at the moment. I was able to easily work around the problem by using NumPy arrays instead of Python lists. Now I'll have to move on and get something working to see whether this framework is suitable for us.

tchaton commented 1 month ago

Hey @senarvi. Thanks for trying LitData, let me know how it goes ;) Use the latest version too. I will make a new release tomorrow.