Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

Time per sample grows as processed samples grows #119

Closed · scritter closed this 4 months ago

scritter commented 4 months ago

🐛 Bug

Time per sample grows as the number of processed samples grows

To Reproduce

Steps to reproduce the behavior:

Follow the example provided using optimize. Increase the number of samples. Observe either the total time growing nonlinearly with the number of samples, or the iteration rate (iterations/s) dropping as more samples are processed.
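
For reference, a minimal sketch of the kind of call involved (the sample contents and counts are illustrative; the optimize parameters mirror the timing script later in this thread):

from litdata import optimize

def make_sample(i):
    # Tiny synthetic sample; the contents don't matter for the timing behaviour.
    return {"sample": "A" * 100}

# Grow the number of inputs and watch the reported time per sample climb.
optimize(
    fn=make_sample,
    inputs=list(range(1_000_000)),
    output_dir="/tmp/test-dataset",
    num_workers=1,
    chunk_bytes="64MB",
)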

Expected behavior

I would naively expect the time per sample to be roughly independent of the total number of samples.

Additional context

This is a bit of a blocker for large datasets. I tried to preprocess a text dataset composed of a couple hundred million samples. After a few minutes, the time to completion looked to be around ~1 hour on the instance I was using. Coming back an hour later, that estimate had ballooned to ~15 hours. Watching the iterations/s, I can see it gradually decreasing as more and more samples are processed.

When I ran an experiment measuring total dataset-preparation time with increasing numbers of samples, I found an approximately linear increase in time per sample as the total number of samples grew.

This is a major deterrent to using it for large datasets. Any guidance?

github-actions[bot] commented 4 months ago

Hi! Thanks for your contribution, great first issue!

tchaton commented 4 months ago

Hey @scritter. Any chance you could provide a reproducible script, or would you be available to get on a call?

sritterginkgo commented 4 months ago

@tchaton thank you for taking the time to respond, and apologies for my delay.

Here's an example script for timing

import io
import time
from contextlib import redirect_stdout, redirect_stderr

import matplotlib.pyplot as plt

from litdata import optimize

def random_sample(i):
    return {"sample": "A"*100}

def function_to_profile(n):
    """Run optimize over n synthetic samples and return its result."""
    return optimize(fn=random_sample, inputs=list(range(n)), output_dir="/tmp/test-dataset", num_workers=1, chunk_bytes="64MB")

def profile_function(func, max_input_size, step_size, num_repeats):
    sizes = range(1000, max_input_size, step_size)
    times_per_size = []

    for size in sizes:
        # Suppress optimize's progress output so the timing printouts stay readable
        with io.StringIO() as buf_out, io.StringIO() as buf_err, redirect_stdout(buf_out), redirect_stderr(buf_err):
            start_time = time.time()
            for _ in range(num_repeats):
                func(size)
            elapsed_time = time.time() - start_time
        average_time = elapsed_time / num_repeats
        times_per_size.append(average_time / size if size != 0 else 0)  # Time per operation

        print(f"completed {size}")
    return sizes, times_per_size

def plot_results(sizes, times_per_size):
    plt.figure(figsize=(10, 5))
    plt.plot(sizes, times_per_size, marker='o')
    plt.title('Function Execution Time per Sample vs Input Size')
    plt.xlabel('Input Size')
    plt.ylabel('Average Time per Sample (seconds)')
    plt.grid(True)
    plt.show()

# Parameters
max_input_size = 50000
step_size = 10000
num_repeats = 3

# Profiling and plotting
sizes, times_per_size = profile_function(function_to_profile, max_input_size, step_size, num_repeats)
plot_results(sizes, times_per_size)

The resulting plot for a single-worker setup (average time per sample vs. input size): [plot image]

Here's also a brief time-lapse of the progress bar for a 32-worker setup on a 32-core instance with a large target number of random samples.

[screenshot: progress bar time-lapse]

And the plot of that relationship:

[plot image]

I welcome any feedback or issues that you see!

tchaton commented 4 months ago

Thanks @sritterginkgo, I will look into it.

sritterginkgo commented 4 months ago

Any updates here @tchaton ? Is there any other information that would be helpful for me to provide? Or separately, is this indeed a phenomenon that you and the team have observed?

tchaton commented 4 months ago

Hey @sritterginkgo, I haven't observed it myself. I am a bit packed right now, so I don't have time to look into it. I wondered if you would have some bandwidth to investigate?

ouj commented 4 months ago

I have observed the same issue for some of my datasets.

In one case, over about 4 days, the training time grew from about 50 mins per epoch to almost 4 hours per epoch.

[Screenshot 2024-05-21 at 7:35:53 PM]

ouj commented 4 months ago

[Screenshot 2024-05-21 at 7:38:55 PM]

Another thing I noticed is that litdata, compared to the streaming dataset from MosaicML, underutilizes memory. Could the slowdown be coming from relying heavily on the on-disk cache?

tchaton commented 4 months ago

Hey @ouj, interesting. You have a slightly different issue: this one tracks increasing time during optimize, not in the StreamingDataset. So it would be better for you to open another issue.

However, I have never noticed this behaviour in our internal testing. Do you think you could provide a reproducible script with synthetic data?

Are you re-creating a new dataloader for every epoch?

ouj commented 4 months ago

@tchaton, yes. I misread this issue, which referred to the slowdown in preparing the data.

I will file a new issue, but I don't know if I can provide a good reproducible script, because the problem doesn't seem to reproduce across all the datasets and it is expensive to repro (those training runs were on A100s with DDP for a couple of days).

Are you re-creating a new dataloader for every epoch?

I am using LightningDataModule without doing anything specific. Do you mean setting reload_dataloaders_every_epoch=True on the Trainer? I can try that.

ouj commented 4 months ago

Filed a new issue: https://github.com/Lightning-AI/litdata/issues/138

sritterginkgo commented 4 months ago

I might find an opportunity to investigate, @tchaton. Is there a particular example that the team has used with synthetic data that I could use as a counterexample, to make sure the issue isn't something in my setup?

tchaton commented 4 months ago

I might find an opportunity to investigate, @tchaton. Is there a particular example that the team has used with synthetic data that I could use as a counterexample, to make sure the issue isn't something in my setup?

Hey @sritterginkgo, I trust you and the issue. I am just saying that I haven't encountered it myself yet. Here is our preparation of TinyLLAMA: https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset?section=all&query=tinyLLAMA

CPU Utilization was at 100% for 6 hours.

[screenshot: CPU utilization]

sritterginkgo commented 4 months ago

Oh thank you! This is helpful.

I will add that in my experiments CPU utilization hasn't been an issue; it has been essentially 100% across workers throughout (past the initial setup). When I tried to record timings of sub-steps over time, I didn't get down to the bedrock operations. I might continue with those efforts to see if there's some data structure or queue that's holding things up.

tchaton commented 4 months ago

Hey @sritterginkgo. Thanks for the update. If you have time, we can schedule a call to pair debug it and hopefully, find the source.

sritterginkgo commented 4 months ago

@tchaton hello!

In the following, I'm referring to having the function called by optimize be a generator that yields items rather than one that returns individual items, similar to your linked examples converting large datasets from collections of files.

I had an opportunity to explore this a little more, and I believe I found the issue, or at least one of them. In writer.py, the BinaryWriter, upon adding an item, runs the _should_write function, which contains a loop over all serialized items currently in the cache. For reasonably sized chunks this can mean billions or trillions of iterations inside this loop, and as I understand the logic, almost all of that work is wasted recomputing the same things over and over. I made a modification to writer.py exploring a solution that keeps a running total in the writer and just checks the new addition. I'll paste the changes at the bottom.

The speedups I'm seeing are enormous: preprocessing a synthetic 100M-sample collection (split across generators of 1e5 samples each) completed in 2 minutes on an 8-core machine. If you have an opportunity, I would highly recommend looking into the suggestion; it may offer speedups for your users.

Let me know if I missed anything!
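
For intuition, here's a toy, self-contained comparison (hypothetical classes, not litdata's actual writer) of rescanning the cache on every add versus keeping a running total. The scanning version should take noticeably longer for the same inputs, since its per-add cost grows with the cache size.

import time

class ScanningWriter:
    """Re-scans every cached item on each add, mimicking the full-scan pattern."""

    def __init__(self, chunk_bytes):
        self.chunk_bytes = chunk_bytes
        self.items = {}

    def add(self, index, data):
        self.items[index] = data
        total = 0
        for item in self.items.values():  # O(cache size) work for every added item
            total += len(item)
        if total >= self.chunk_bytes:
            self.items.clear()            # stand-in for writing the chunk out

class RunningTotalWriter:
    """Keeps a running byte total, so each add costs O(1)."""

    def __init__(self, chunk_bytes):
        self.chunk_bytes = chunk_bytes
        self.items = {}
        self.total = 0

    def add(self, index, data):
        self.items[index] = data
        self.total += len(data)
        if self.total >= self.chunk_bytes:
            self.items.clear()
            self.total = 0

for cls in (ScanningWriter, RunningTotalWriter):
    writer = cls(chunk_bytes=500_000)
    start = time.time()
    for i in range(20_000):
        writer.add(i, "A" * 100)
    print(f"{cls.__name__}: {time.time() - start:.3f}s")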

Example call:

# 1000 generator inputs x 1e5 yielded samples each = 1e8 samples in total
total_samples = 1e8

def random_batch(x):
    # Each input index yields 1e5 synthetic samples (random_sample as defined in the timing script above)
    for _ in range(int(1e5)):
        yield random_sample(0)

optimize(fn=random_batch, inputs=list(range(1000)), output_dir="/tmp/test-dataset", num_workers=8, chunk_bytes="64MB", fast_dev_run=False)
class BinaryWriter:
    def __init__(self):
        # ... existing initialization unchanged; only the two running totals below are new ...
        self._per_sample_num_bytes = 0
        self._per_sample_num_items = 0

    def add_item(self, index: int, items: Any) -> Optional[str]:
        # Serialize the items and store an Item object.
        if index in self._serialized_items:
            raise ValueError(f"The provided index {index} already exists in the cache.")

        data, dim = self.serialize(items)
        self._serialized_items[index] = Item(
            index=index,
            data=data,
            bytes=len(data),
            dim=dim,
        )
        # Track the minimum index provided to the writer
        if self._min_index is None:
            self._min_index = index
        if not self._should_write_per_item(index):
            return None
        filepath = os.path.join(self._cache_dir, self.get_chunk_filename())
        self.write_chunk()
        # Reset the running totals to the item that triggered the write
        self._per_sample_num_bytes = self._serialized_items[index].bytes
        self._per_sample_num_items = self._serialized_items[index].dim if self._serialized_items[index].dim else 1
        self._min_index = index
        self._max_index = None
        return filepath

    def _should_write_per_item(self, index):
        """Replacement for _should_write: the original loops over every cached item for
        each added item, which slows down hugely as the number of samples grows. This
        version keeps running totals and only accounts for the newly added item."""
        item = self._serialized_items.get(index, None)
        if item:
            self._per_sample_num_bytes += item.bytes
            self._per_sample_num_items += item.dim if item.dim else 1
            if (self._chunk_bytes and self._chunk_bytes < self._per_sample_num_bytes) or (
                self._chunk_size and self._per_sample_num_items > self._chunk_size
            ):
                self._max_index = index
                return True
        return False

tchaton commented 4 months ago

Hey @sritterginkgo. That's awesome. Do you want to make a PR to fix it ?

Some context: the main reason this loop was there in the first place was to let users return samples in an unordered way, since the downloaders are asynchronous, while still getting the samples properly ordered. However, this is less of a requirement today.

I think it would be good to keep the option by adding a new argument keep_ordered=False to the BinaryWriter, using your code by default and falling back to the current logic otherwise.
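
A rough sketch of how that flag could be wired (hypothetical class and names, not litdata's actual BinaryWriter; it also doesn't model the post-write reset, and the "ordered" branch stands in for the current scan-based check):

from typing import Dict

class ChunkSizeChecker:
    def __init__(self, chunk_bytes: int, keep_ordered: bool = False):
        self.chunk_bytes = chunk_bytes
        self.keep_ordered = keep_ordered
        self.sizes: Dict[int, int] = {}  # index -> serialized size in bytes
        self.running_bytes = 0

    def should_write(self, index: int, num_bytes: int) -> bool:
        self.sizes[index] = num_bytes
        if self.keep_ordered:
            # current behaviour: re-scan everything so unordered indexes still work
            return sum(self.sizes.values()) >= self.chunk_bytes
        # proposed default: constant-time running total
        self.running_bytes += num_bytes
        return self.running_bytes >= self.chunk_bytes

checker = ChunkSizeChecker(chunk_bytes=300, keep_ordered=False)
print([checker.should_write(i, 100) for i in range(4)])  # [False, False, True, True]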

sritterginkgo commented 4 months ago

Thanks for the feedback! Happy to propose a PR; are there any restrictions, or can I just submit one?

Do you or the team have any document explaining in more detail how samples would flow into this in the unordered context? I did indeed make some very particular assumptions about the order, but I'm curious whether there are higher-level patterns that would still be maintained.

As I understand the code, it will work with unordered input as long as that input comes from some contiguous collection of indices. Otherwise it seems like there would be some odd failure modes. For example:

Take a chunk size of 3 and an unordered, growing cache. There might be some intermediate state such as (extra spacing for visual clarity): [1, 2,  4, 5, 6, 7, 8]. The current code will find a minimum of index 1, but will not write the cache until 3 is received.

Is that a correct intermediate state? Is the assumption then that a write will always receive a contiguous set of indices? (That seems to be correct, but I'm still new to the codebase.)

If so, there could be an alternative to what I proposed that checks on each add whether index < current_min. If it is, it could reset the min and re-run the full loop logic; if not, it would just use the single-addition logic. That should support the above case while keeping a higher-performance route for what is probably now the standard context. A rough sketch follows.
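
A self-contained sketch of that hybrid approach (hypothetical class, not a drop-in for litdata's BinaryWriter): keep a running total for in-order adds and only re-run the full scan when an index arrives below the current minimum.

class HybridChunkTracker:
    def __init__(self, chunk_bytes: int):
        self.chunk_bytes = chunk_bytes
        self.sizes = {}          # index -> serialized size in bytes
        self.running_bytes = 0
        self.min_index = None

    def add(self, index: int, num_bytes: int) -> bool:
        self.sizes[index] = num_bytes
        if self.min_index is None:
            self.min_index = index
            self.running_bytes += num_bytes
        elif index < self.min_index:
            # rare out-of-order arrival below the window: rebuild the total from scratch
            self.min_index = index
            self.running_bytes = sum(self.sizes.values())
        else:
            # common case: constant-time update
            self.running_bytes += num_bytes
        return self.running_bytes >= self.chunk_bytes  # True => time to write a chunk

tracker = HybridChunkTracker(chunk_bytes=1_000)
print([tracker.add(i, 100) for i in (3, 4, 5, 1, 2, 6, 7, 8, 9, 10)])  # True only on the last add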

tchaton commented 4 months ago

Hey @sritterginkgo, go ahead with submitting a PR and adding me for review. Make sure to add some tests too ;)

Do you or the team have any document explaining in more detail how samples would flow into this in the unordered context?

I haven't written any. If you do local processing, as you do, it will be ordered. However, litdata was mostly built to process datasets in the cloud. If the inputs contain files in the cloud to download, the downloaders within the workers fetch them and push each index to a queue as soon as its file is ready. This isn't guaranteed to be ordered. Here: https://github.com/Lightning-AI/litdata/blob/main/src/litdata/processing/data_processor.py#L180

This code was meant to maintain the same order between the inputs the user provides to the optimize function and the index of each sample within the dataset.
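
For illustration, a generic standard-library sketch (unrelated to litdata's internals) of why asynchronous downloads naturally complete out of order:

import concurrent.futures
import random
import time

def fake_download(i):
    time.sleep(random.uniform(0, 0.05))  # each "download" finishes at an unpredictable time
    return i

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(fake_download, i) for i in range(10)]
    completion_order = [f.result() for f in concurrent.futures.as_completed(futures)]

print(completion_order)  # indices arrive in a different order than they were submitted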

If so, there could be an alternative to what I proposed that checks on each add whether index < current_min.

This seems like an interesting idea.

sritterginkgo commented 4 months ago

Given the current logic, how is the case below protected against?

If samples are received out of order, there may be some intermediate state that excludes the edges of what the final chunk would be, for example receiving only:

3, 4, 5

The current logic will find 3, iterate up to 5, and, if that were the maximum chunk size allowed, write that section out. Then we would receive the remaining:

1, 2, 6

These would never pass the chunk-completion logic as it exists (because they are not contiguous) and would sit in the cache until the final flush, which will throw an error during the write step.

Does this just functionally never happen? Or am I misinterpreting other parts of the logic?

I ask because we will also have some rather large datasets that could benefit from cloud runs, and I'd love to help add a solution that provides a lot of utility.

tchaton commented 4 months ago

Hey @sritterginkgo. Actually, I think my code is wrong anyway, because min_index doesn't start from 0; it is based on the received indexes. So if you receive 2, 4, 5, 0, it won't behave properly.

Let's go with your initial solution. If users want to keep the order, they can keep an index on the side or store it within the sample.
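
For example, a minimal way to carry the original index inside each sample (the field name is just illustrative), reusing the random_sample shape from earlier in the thread:

def random_sample(i):
    # Keep the original input index alongside the payload so order can be recovered later.
    return {"index": i, "sample": "A" * 100}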

sritterginkgo commented 4 months ago

Opened PR #146: https://github.com/Lightning-AI/litdata/pull/146

tchaton commented 4 months ago

Closing from https://github.com/Lightning-AI/litdata/pull/146