Lightning-AI / litdata

Streamline data pipelines for AI. Process datasets across 1000s of machines, and optimize data for blazing fast model training.
Apache License 2.0

Index Error when calling StreamingDataLoader.state_dict() when using custom collate_fn with multiple workers #196

Closed: esivonxay-cognitiv closed this issue 1 day ago

esivonxay-cognitiv commented 2 days ago

🐛 Bug

To Reproduce

Steps to reproduce the behavior: Add this unit test to the test_dataloader.py file and run it.

def test_custom_collate_multiworker():
    dataset = TestCombinedStreamingDataset(
        [TestStatefulDatasetDict(10, 1), TestStatefulDatasetDict(10, -1)],
        42,
        weights=(0.5, 0.5),
        iterate_over_all=False,
    )
    assert dataset._datasets[0].shuffle is None
    assert dataset._datasets[1].shuffle is None
    dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=3, shuffle=True, collate_fn=custom_collate_fn)
    assert dataset._datasets[0].shuffle
    assert dataset._datasets[1].shuffle
    dataloader_iter = iter(dataloader)
    assert next(dataloader_iter) == "received"
    assert dataloader._num_samples_yielded_combined[0] == [2]
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"

    dataloader.state_dict()
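
The test above references custom_collate_fn without showing its definition. Here is a minimal sketch of what such a function could look like, assuming it only needs to satisfy the `== "received"` assertions. The body is a hypothetical stand-in, not the definition from the original test file:

```python
from typing import Any, List


def custom_collate_fn(samples: List[Any]) -> str:
    """Hypothetical stand-in for the collate function used by the test.

    It ignores the contents of the batch and returns a fixed sentinel,
    so the test can check that the custom function was actually called.
    """
    # Assumption: the test only compares the batch against "received",
    # so nothing from `samples` needs to be preserved.
    return "received"
```

Defining the function at module level rather than inside the test keeps it picklable, which matters when num_workers is greater than 0 and the platform starts workers with the spawn method.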

Expected behavior

The state_dict() method should execute without any errors.

Environment

github-actions[bot] commented 2 days ago

Hi! Thanks for your contribution! Great first issue!

tchaton commented 1 day ago

Hey @esivonxay-cognitiv, would you be interested in attempting a fix and submitting a PR?

esivonxay-cognitiv commented 1 day ago

Yeah, I'll give it a shot.