Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

Using a streaming dataloader with an unbalanced dataset yields unexpected batch sizes. #199

Open esivonxay-cognitiv opened 6 days ago

esivonxay-cognitiv commented 6 days ago

🐛 Bug

I have two unbalanced datasets, one of which is 1000x larger than the other. I would like to sample from the two datasets such that the ratio of samples from each is 1:100. When doing so, batches of irregular size are returned during iteration.

I think this test surfaces two issues:

1. The first batch returned by each worker is not properly sized.
2. drop_last does not appear to work as intended, since the last batch is not a full-sized batch.

I don't think this is related to #179, but it's possible.

I've been attempting to fix this, but I'm not sure what the root of the issue is. I would be very appreciative if you could fix this or point me in the right direction.

Thanks!

To Reproduce

# Imports below are inferred for this snippet (standard library plus litdata's streaming API)
import os
import sys

import pytest

from litdata.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from litdata.streaming.resolver import Dir


@pytest.mark.skipif(sys.platform == "win32", reason="too slow in CI")
def test_unbalanced_combined_dataset_with_dataloader(tmpdir):
    data_dir_1 = os.path.join(tmpdir, "data_1")
    data_dir_2 = os.path.join(tmpdir, "data_2")
    cache_dir_1 = os.path.join(tmpdir, "cache_dir_1")
    cache_dir_2 = os.path.join(tmpdir, "cache_dir_2")

    os.makedirs(data_dir_1)
    os.makedirs(data_dir_2)
    os.makedirs(cache_dir_1)
    os.makedirs(cache_dir_2)

    # Small dataset: 10 samples written in chunks of 2
    cache = Cache(input_dir=str(data_dir_1), chunk_size=2)

    for i in range(10):
        cache[i] = i

    cache.done()
    cache.merge()

    # Large dataset: 10,000 samples written in chunks of 2
    cache = Cache(input_dir=str(data_dir_2), chunk_size=2)

    for i in range(10000):
        cache[i] = i + 10

    cache.done()
    cache.merge()

    dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
    dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)

    # Combine the streams so ~1% of samples are drawn from dataset1 and ~99% from dataset2
    dataset = CombinedStreamingDataset(
        datasets=[dataset1, dataset2], weights=[0.01, 0.99], iterate_over_all=False, seed=12345
    )
    dataloader = StreamingDataLoader(
        dataset,
        num_workers=3,
        batch_size=100,
        drop_last=True,
        persistent_workers=True,
        shuffle=True,
        prefetch_factor=2,
    )

    assert dataset1.current_epoch == 1
    assert dataset2.current_epoch == 1

    batches_1 = []
    batch_sizes_1 = []
    for batch in dataloader:
        batch_sizes_1.append(batch.size(0))
        batches_1.append(batch)

    assert batch_sizes_1[2] == 91
    assert batch_sizes_1[-1] == 40
    # The two assertions above pass, but this one fails: the third and last batches are not of size 100.
    assert batch_sizes_1 == [100 for _ in batches_1]

Expected behavior

All batches should have the same size (100), since batch_size=100 and drop_last=True.

Additional context

This issue occurs regardless of whether drop_last, shuffle, and persistent_workers are set to True or False.
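
For reference, this is the drop_last behavior I would expect, shown as a minimal sketch with a plain torch.utils.data.DataLoader standing in for the streaming setup (not litdata): with drop_last=True, every yielded batch is exactly batch_size items and the incomplete tail is silently dropped.

import torch
from torch.utils.data import DataLoader, TensorDataset

# 1010 samples with batch_size=100: drop_last=True should drop the trailing 10
dataset = TensorDataset(torch.arange(1010))
loader = DataLoader(dataset, batch_size=100, drop_last=True)

sizes = [batch[0].size(0) for batch in loader]
assert sizes == [100] * 10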

tchaton commented 5 days ago

Hey @esivonxay-cognitiv, thanks for the reproducible script. I will have a look into it.

esivonxay-cognitiv commented 5 days ago

Thanks Thomas!

tchaton commented 5 days ago

Hey @esivonxay-cognitiv, I am curious: what's your interest in and usage of LitData?

esivonxay-cognitiv commented 3 days ago

Yeah, I'm interested in LitData primarily for the ability to sample from multiple streams. I've got two datasets that are quite imbalanced (one is 100,000x larger than the other), and I'm trying to downsample the larger one to reduce the imbalance by a couple of orders of magnitude.

Naively, I could do this when constructing the dataset by throwing out datapoints. However, doing so would mean throwing out 90% or 99% of the data (to decrease the imbalance by 10x or 100x, respectively), and important samples could be lost in the process.

My thought was to do this downsampling/rebalancing during dataloading so the model at least has a chance to see each sample, just at a lower rate.
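
Concretely, the idea is something like the following sketch (the directory paths and weights here are made up for illustration): let the combined dataset do the rebalancing at load time instead of discarding samples when building the dataset.

from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Hypothetical local dataset directories; the weights are illustrative only
small = StreamingDataset("data/small", shuffle=True)   # minority stream
large = StreamingDataset("data/large", shuffle=True)   # much larger stream

# Sample the small stream far more often than its natural frequency, so the
# large stream is effectively downsampled without throwing any data away
combined = CombinedStreamingDataset(
    datasets=[small, large], weights=[0.01, 0.99], iterate_over_all=False
)
loader = StreamingDataLoader(combined, batch_size=100, num_workers=3)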