Lightning-AI / litdata

Streamline data pipelines for AI. Process datasets across 1000s of machines, and optimize data for blazing fast model training.
Apache License 2.0
249 stars 24 forks

train_test_split fails when asked for `splits=[0.1, 0.2, 0.7]` #186

Closed deependujha closed 4 days ago

deependujha commented 4 days ago

🐛 Bug

train_test_split works perfectly when asked to split dataset in splits=[0.1, 0.7, 0.2], but it fails when asked for splits=[0.1, 0.2, 0.7].

To Reproduce

Try this script:

from litdata import train_test_split, StreamingDataset, StreamingDataLoader

x, y, z = train_test_split(streaming_dataset=StreamingDataset("output_dir"), splits=[0.1, 0.2, 0.7])

print(f"{len(x)=}")
print(f"{len(y)=}")
print(f"{len(z)=}")

print(f"{x[:]=}")
print(f"{y[:]=}")
print(f"{z[:]=}") # this will raise error

x = StreamingDataLoader(x, batch_size=5)
y = StreamingDataLoader(y, batch_size=5)
z = StreamingDataLoader(z, batch_size=5)

print("-"*80)
print("iterate X")
for _x in x:
    print(_x)

print("-"*80)
print("iterate Y")
for _y in y:
    print(_y)

print("-"*80)
print("iterate Z")
for _z in z: # this will raise error
    print(_z)

print("-"*80)
print("All done!")

Code sample

Code for output_dir:

from litdata import optimize

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=4,
    output_dir="output_dir",
    chunk_bytes="64MB",
    mode="overwrite",
)

Expected behavior

It should work irrespective of their order.

Environment

Additional context

It's happening because of a logic issue in subsample_filenames_and_roi().
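For context, splitting should just partition the index range by the requested fractions, so the result depends only on the fractions themselves, not on their order. A minimal sketch of that expected behavior (not litdata's actual implementation, and simpler than the chunk/ROI bookkeeping that subsample_filenames_and_roi() does):

```python
def split_lengths(total, splits):
    # Convert fractional splits into item counts by flooring.
    # (litdata may round differently; this is just illustrative.)
    return [int(total * frac) for frac in splits]

def split_ranges(total, splits):
    # Partition [0, total) into consecutive index ranges, one per split.
    ranges, start = [], 0
    for count in split_lengths(total, splits):
        ranges.append(range(start, start + count))
        start += count
    return ranges

# For the 100-item dataset from the repro, both orderings should yield
# the same set of split sizes, just assigned to different positions:
print(split_lengths(100, [0.1, 0.2, 0.7]))  # [10, 20, 70]
print(split_lengths(100, [0.1, 0.7, 0.2]))  # [10, 70, 20]
```

Under this model, splits=[0.1, 0.2, 0.7] and splits=[0.1, 0.7, 0.2] produce the same three partitions in a different order, which is why the failure only for one ordering points at the subsampling logic rather than the split sizes.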

github-actions[bot] commented 4 days ago

Hi! Thanks for your contribution, great first issue!