Closed: esivonxay-cognitiv closed this issue 1 day ago.
🐛 Bug

To Reproduce

Steps to reproduce the behavior: add this unit test to the test_dataloader.py file and run it.

```python
# TestCombinedStreamingDataset, TestStatefulDatasetDict and custom_collate_fn
# are the helpers already defined in test_dataloader.py.
def test_custom_collate_multiworker():
    dataset = TestCombinedStreamingDataset(
        [TestStatefulDatasetDict(10, 1), TestStatefulDatasetDict(10, -1)],
        42,
        weights=(0.5, 0.5),
        iterate_over_all=False,
    )
    # shuffle is unset until the dataloader configures the datasets
    assert dataset._datasets[0].shuffle is None
    assert dataset._datasets[1].shuffle is None

    # Multi-worker loading with a custom collate function
    dataloader = StreamingDataLoader(
        dataset, batch_size=2, num_workers=3, shuffle=True, collate_fn=custom_collate_fn
    )
    assert dataset._datasets[0].shuffle
    assert dataset._datasets[1].shuffle

    dataloader_iter = iter(dataloader)
    assert next(dataloader_iter) == "received"
    assert dataloader._num_samples_yielded_combined[0] == [2]
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"

    # Expected to return without raising
    dataloader.state_dict()
```
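The test compares every batch to the string "received", so it depends on `custom_collate_fn` from test_dataloader.py returning that value. A minimal stand-in, assuming the helper only needs to discard the collated samples and return that marker (the real helper may also validate the samples), could look like this:

```python
from typing import Any, List


def custom_collate_fn(samples: List[Any]) -> str:
    """Hypothetical stand-in for the helper in test_dataloader.py.

    The repro only asserts that each batch equals "received", so this
    sketch ignores the sample contents and returns that marker string.
    """
    # The real helper may also check the samples; that is not assumed here.
    del samples
    return "received"
```

Because a collate function like this replaces each batch with a constant, the assertions mainly exercise the dataloader's iteration and bookkeeping (such as `_num_samples_yielded_combined`) rather than the data itself.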
Expected behavior

The `state_dict()` method should execute without any errors.

Environment

Installation method: pip
Hi! Thanks for your contribution, great first issue!
Hey @esivonxay-cognitiv, would you be interested in attempting a fix and submitting a PR?
Yeah, I'll give it a shot.