Lightning-AI / litdata

Streamline data pipelines for AI. Process datasets across 1000s of machines, and optimize data for blazing fast model training.
Apache License 2.0

Fix: unexpected behaviours (bugs) in train_test_split fixed #192

Closed deependujha closed 3 days ago

deependujha commented 3 days ago
Before submitting

- [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lit-data/blob/main/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

What does this PR do?

train_test_split works correctly when asked to split a dataset with splits=[0.1, 0.7, 0.2], but it fails with splits=[0.1, 0.2, 0.7].

With the original code, the following script fails:

from litdata import StreamingDataLoader, StreamingDataset, train_test_split

x, y, z = train_test_split(streaming_dataset=StreamingDataset("output_dir"), splits=[0.1, 0.2, 0.7])

print(f"{len(x)=}")
print(f"{len(y)=}")
print(f"{len(z)=}")

print(f"{x[:]=}")
print(f"{y[:]=}")
print(f"{z[:]=}") # this will raise error

x = StreamingDataLoader(x, batch_size=5)
y = StreamingDataLoader(y, batch_size=5)
z = StreamingDataLoader(z, batch_size=5)

print("-"*80)
print("iterate X")
for _x in x:
    print(_x)

print("-"*80)
print("iterate Y")
for _y in y:
    print(_y)

print("-"*80)
print("iterate Z")
for _z in z: # this will raise error
    print(_z)

print("-"*80)
print("All done!")

Apart from the failure, the values printed by y[:] overlap with those printed by x[:]. This was because of the way the reader was reading from chunks.
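
The overlap can also be checked programmatically rather than by reading the printed values. Below is a minimal sketch; `find_overlaps` is a hypothetical helper and not part of litdata, and it assumes each sample is a small sequence such as the `(index, index ** 2)` pairs used here. The lists are toy stand-ins for `x[:]` and `y[:]`, not real output from the script above.

```python
from collections import Counter
from typing import Iterable, Sequence


def find_overlaps(splits: Sequence[Iterable]) -> set:
    """Return every sample that appears in more than one split."""
    seen_in = Counter()
    for split in splits:
        # De-duplicate within a split so only cross-split repeats are counted.
        for sample in {tuple(item) for item in split}:
            seen_in[sample] += 1
    return {sample for sample, count in seen_in.items() if count > 1}


# Toy stand-ins shaped like the (index, index ** 2) samples.
x_items = [(i, i ** 2) for i in range(0, 10)]
y_overlapping = [(i, i ** 2) for i in range(5, 25)]  # shares indices 5..9 with x_items
y_disjoint = [(i, i ** 2) for i in range(10, 30)]

print(sorted(find_overlaps([x_items, y_overlapping])))
print(find_overlaps([x_items, y_disjoint]))
```

With real splits, passing `[x[:], y[:], z[:]]` should give an empty set once the splits no longer overlap.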


Code to generate output_dir:

from litdata import optimize

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=4,
    output_dir="output_dir",
    chunk_bytes="64MB",
    mode="overwrite",
)

These bugs are fixed in this PR. It originally aimed to close issue #186, but that issue has already been closed because of some confusion.
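
For reviewers, the behaviour one would expect from any order of `splits` can be sketched as follows. This is only an illustration of the invariant (contiguous, non-overlapping ranges whose sizes do not depend on the order of the fractions), not the implementation in this PR. `split_ranges` is a hypothetical name, and rounding the cumulative fraction is an assumption made here to sidestep floating-point error.

```python
from typing import List, Tuple


def split_ranges(num_items: int, splits: List[float]) -> List[Tuple[int, int]]:
    """Map fractions to contiguous, non-overlapping [start, end) index ranges."""
    if any(s <= 0 for s in splits) or sum(splits) > 1 + 1e-9:
        raise ValueError("splits must be positive and sum to at most 1")
    ranges = []
    start = 0
    cumulative = 0.0
    for fraction in splits:
        cumulative += fraction
        # Round the running total so float error cannot shift a boundary,
        # and clamp so the last range never runs past the dataset.
        end = min(num_items, round(cumulative * num_items))
        ranges.append((start, end))
        start = end  # the next split begins exactly where this one ends
    return ranges


print(split_ranges(100, [0.1, 0.7, 0.2]))
print(split_ranges(100, [0.1, 0.2, 0.7]))
```

Because each range starts where the previous one ends, no sample can land in two splits, and both orderings yield splits of 10, 20 and 70 items for the 100-item dataset used above.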

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

codecov[bot] commented 3 days ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@f2c5a7b). Learn more about missing BASE report.

Additional details and impacted files

```diff
@@          Coverage Diff          @@
##             main   #192   +/-   ##
=====================================
  Coverage        ?    78%
=====================================
  Files           ?     33
  Lines           ?   4488
  Branches        ?      0
=====================================
  Hits            ?   3492
  Misses          ?    996
  Partials        ?      0
```

tchaton commented 3 days ago

Hey @deependujha. Can you describe which bugs this is fixing?

deependujha commented 3 days ago

> Hey @deependujha. Can you describe which bugs this is fixing ?

Sorry for the delayed response. I've updated the description. Please have a look. This PR is an extension of one that has already been merged (#187).