pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Invalidate `DataLoader2Iterator` would potentially violate other iterators by `finalize_iteration` reading service #820

Open ejguan opened 1 year ago

ejguan commented 1 year ago

🐛 Describe the bug

We call finalize_iteration when __next__ is invoked on an invalidated DataLoader2Iterator (let's call it Iter_A):

https://github.com/pytorch/data/blob/c42587a828d05f24f6f0586d17d3e9d55e1433ed/torchdata/dataloader2/dataloader2.py#L67-L68

However, by that point a new DataLoader2Iterator (Iter_B) has been created, and the same ReadingService has already called initialize here: https://github.com/pytorch/data/blob/c42587a828d05f24f6f0586d17d3e9d55e1433ed/torchdata/dataloader2/dataloader2.py#L135

If someone then calls next(Iter_A), it calls finalize_iteration on that same ReadingService, and Iter_B exits earlier than expected.
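The sequence above can be reproduced with a toy model. This is only a sketch: ToyReadingService, ToyLoader and ToyIterator are hypothetical stand-ins, not torchdata classes, and the rule that an iterator stops as soon as the shared service is no longer active is an assumption used to model Iter_B exiting early.

```python
class ToyReadingService:
    """Hypothetical stand-in for a ReadingService; it only tracks state."""

    def __init__(self):
        self.active = False
        self.events = []

    def initialize_iteration(self):
        self.active = True
        self.events.append("initialize")

    def finalize_iteration(self):
        self.active = False
        self.events.append("finalize")


class ToyLoader:
    """Hypothetical loader that hands out iterators sharing one service."""

    def __init__(self, data, reading_service):
        self.data = list(data)
        self.reading_service = reading_service
        self.valid_iterator_id = -1

    def __iter__(self):
        self.valid_iterator_id += 1  # invalidates any earlier iterator
        self.reading_service.initialize_iteration()
        return ToyIterator(self, self.valid_iterator_id)


class ToyIterator:
    def __init__(self, loader, iterator_id):
        self.loader = loader
        self.iterator_id = iterator_id
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        service = self.loader.reading_service
        if self.iterator_id != self.loader.valid_iterator_id:
            # Behaviour described in this issue: the stale iterator
            # finalizes the service that the newer iterator is using.
            service.finalize_iteration()
            raise RuntimeError("iterator invalidated by a newer one")
        if not service.active or self.pos >= len(self.loader.data):
            # Assumption: a finalized service yields no more data.
            raise StopIteration
        item = self.loader.data[self.pos]
        self.pos += 1
        return item


rs = ToyReadingService()
dl = ToyLoader(range(5), rs)
iter_a = iter(dl)
iter_b = iter(dl)          # Iter_A is now invalid
first = next(iter_b)       # Iter_B works as expected so far
try:
    next(iter_a)           # stale call finalizes the shared service
except RuntimeError:
    pass
rest = list(iter_b)        # Iter_B stops with four items still unread
```

In this model, Iter_B yields one item and then stops, because the stale next(iter_a) call finalized the service it depends on.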

Solution

We should move all logic from __next__ to __init__ for DataLoader2Iterator

cc: @NivekT

Versions

main branch

NivekT commented 1 year ago

I think the solution might be:

  1. Call finalize_iteration within DataLoader2's __iter__ or DataLoader2Iterator's __init__ to ensure that, if an existing RS is still mid-iteration, it gets terminated properly
    • Thoughts on whether this is necessary?
  2. We should still check whether the iterator is valid when __next__ is called; otherwise we can't know?
  3. When the iterator is invalid, don't call finalize_iteration; just raise an exception

Thoughts?
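A minimal sketch of points 2 and 3 on their own, with hypothetical stub classes (StubReadingService, StubLoader, StubIterator) rather than the real torchdata ones: the stale iterator raises without touching the reading service, so the newer iterator runs to completion.

```python
class StubReadingService:
    """Hypothetical stand-in that only records lifecycle calls."""

    def __init__(self):
        self.events = []

    def initialize_iteration(self):
        self.events.append("initialize")

    def finalize_iteration(self):
        self.events.append("finalize")


class StubLoader:
    def __init__(self, data, reading_service):
        self.data = list(data)
        self.reading_service = reading_service
        self.valid_iterator_id = -1

    def __iter__(self):
        self.valid_iterator_id += 1
        self.reading_service.initialize_iteration()
        return StubIterator(self, self.valid_iterator_id)


class StubIterator:
    def __init__(self, loader, iterator_id):
        self.loader = loader
        self.iterator_id = iterator_id
        self.source = iter(loader.data)

    def __iter__(self):
        return self

    def __next__(self):
        # Point 2: validity is still checked on every __next__ call.
        if self.iterator_id != self.loader.valid_iterator_id:
            # Point 3: raise only; do not finalize the shared service.
            raise RuntimeError("iterator invalidated by a newer one")
        try:
            return next(self.source)
        except StopIteration:
            # Only the valid iterator finalizes, once it is exhausted.
            self.loader.reading_service.finalize_iteration()
            raise


rs = StubReadingService()
dl = StubLoader(range(3), rs)
iter_a = iter(dl)
iter_b = iter(dl)
try:
    next(iter_a)
    stale_raised = False
except RuntimeError:
    stale_raised = True
items = list(iter_b)
```

Note that in this sketch the service still sees two initialize_iteration calls in a row with no finalize_iteration between them, which is what point 1 is meant to address.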

ejguan commented 1 year ago

After thinking about those options a bit, I think the only workable solution might be to always call finalize_iteration before initialize_iteration in DL2's __iter__. Consider the following case with the current implementation:

iter_a = iter(dl)
iter_b = iter(dl)  # `initialize_iteration` called first and we never properly call `finalize_iteration` beforehand

So it might be better to call finalize_iteration in __iter__ whenever initialize_iteration has already been called.
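One way this could look, as a sketch only. RecordingService, SketchLoader and SketchIterator are hypothetical, and the _iteration_started flag is an assumed bookkeeping detail, not something taken from the torchdata code base.

```python
class RecordingService:
    """Hypothetical stand-in that records lifecycle calls in order."""

    def __init__(self):
        self.events = []

    def initialize_iteration(self):
        self.events.append("initialize")

    def finalize_iteration(self):
        self.events.append("finalize")


class SketchLoader:
    def __init__(self, data, reading_service):
        self.data = list(data)
        self.reading_service = reading_service
        self.valid_iterator_id = -1
        self._iteration_started = False  # assumed flag, for illustration

    def __iter__(self):
        if self._iteration_started:
            # A previous iteration was initialized but never finalized:
            # close it before starting the next one.
            self.reading_service.finalize_iteration()
        self.reading_service.initialize_iteration()
        self._iteration_started = True
        self.valid_iterator_id += 1
        return SketchIterator(self, self.valid_iterator_id)


class SketchIterator:
    def __init__(self, loader, iterator_id):
        self.loader = loader
        self.iterator_id = iterator_id
        self.source = iter(loader.data)

    def __iter__(self):
        return self

    def __next__(self):
        if self.iterator_id != self.loader.valid_iterator_id:
            # Stale iterators raise without touching the service.
            raise RuntimeError("iterator invalidated by a newer one")
        try:
            return next(self.source)
        except StopIteration:
            self.loader.reading_service.finalize_iteration()
            self.loader._iteration_started = False
            raise


rs = RecordingService()
dl = SketchLoader(range(3), rs)
iter_a = iter(dl)
iter_b = iter(dl)   # finalizes iter_a's iteration, then initializes
try:
    next(iter_a)
    stale_raised = False
except RuntimeError:
    stale_raised = True
items = list(iter_b)
```

With this ordering every initialize_iteration is paired with a finalize_iteration, and a late next(iter_a) can no longer cut iter_b short.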

NivekT commented 1 year ago

Agree with that.