pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.

StatefulDataLoader stores worker state twice if the IterableDataset is also an Iterator #1259

Closed gokulavasan closed 1 month ago

gokulavasan commented 1 month ago

🐛 Describe the bug

```python
import torch
from collections.abc import Iterator

# Stateful protocol shipped with the stateful dataloader
from torchdata.stateful_dataloader.stateful import Stateful


class MyIterabledataset(torch.utils.data.IterableDataset, Iterator, Stateful):
    def __init__(self, samples):
        self.samples = samples
        self.size = len(self.samples)
        self.i = 0

    def __iter__(self):
        # The dataset is its own iterator, so the fetcher's iterator and the
        # worker's dataset are the same object.
        return self

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.i]
        self.i += 1
        return sample

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]
```
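
For reference, a minimal sketch of how to observe the issue, assuming a single worker. The exact nesting of the dict returned by `StatefulDataLoader.state_dict()` is an internal detail, but printing it should show the dataset's `{"i": ...}` payload captured both as the fetcher/iterator state and as the dataset state.

```python
from torchdata.stateful_dataloader import StatefulDataLoader

if __name__ == "__main__":
    ds = MyIterabledataset(list(range(10)))
    dl = StatefulDataLoader(ds, batch_size=2, num_workers=1)

    it = iter(dl)
    next(it)  # advance so there is non-trivial state to snapshot

    # Inspect the snapshot; the dataset's {"i": ...} payload appears twice.
    print(dl.state_dict())
```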

In the above example, because the dataset is also the iterator returned by `__iter__`, its state will be stored twice: once as dataset_state (https://github.com/pytorch/data/blob/a0412de86211f30d9c79acbb5dae73fce23e5739/torchdata/stateful_dataloader/worker.py#L219) and again as fetcher_state (https://github.com/pytorch/data/blob/a0412de86211f30d9c79acbb5dae73fce23e5739/torchdata/stateful_dataloader/worker.py#L213).

State can be expensive to store and transfer, so it would be good to avoid replicating it in this scenario.
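
One possible direction, sketched below with hypothetical names (the real logic lives in torchdata/stateful_dataloader/worker.py; this is not its implementation): before serializing the fetcher's iterator state, check whether the iterator is the dataset object itself and skip the redundant copy.

```python
def snapshot_worker_state(dataset, fetcher_iterator):
    # Hypothetical helper, for illustration only.
    dataset_state = None
    fetcher_state = None

    if hasattr(dataset, "state_dict"):
        dataset_state = dataset.state_dict()

    # If the dataset is its own iterator (as in the example above), the
    # iterator state is identical to the dataset state, so skip serializing
    # it a second time.
    if fetcher_iterator is not dataset and hasattr(fetcher_iterator, "state_dict"):
        fetcher_state = fetcher_iterator.state_dict()

    return {"dataset_state": dataset_state, "fetcher_state": fetcher_state}
```

On restore, the same identity check would be needed so that `load_state_dict` is only called once on the shared object.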

Versions

main branch