Hi there,

I’m currently working with the Grain DataLoader and ran into an issue with stopping and resuming training at each epoch boundary when initializing the sampler with num_epochs > 0. I’d like some guidance on how to handle this scenario properly and on how the shuffling mechanism works across epochs.
The Issue:
I’m using the IndexSampler with num_epochs > 0, and my expectation is to be able to stop the dataloader at the end of each epoch and resume training at the next one. However, the sampler seems to provide a total number of items equal to num_records * num_epochs, so it behaves as if all epochs were combined into a single stream.
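For context, my setup looks roughly like this (a minimal sketch: my_data_source, the constants, and worker_count are placeholders, and I’m assuming the grain.python import path and that sampler[i] returns a RecordMetadata with a record_key field):

```python
import grain.python as grain

NUM_RECORDS = 1000  # placeholder dataset size
NUM_EPOCHS = 3

sampler = grain.IndexSampler(
    num_records=NUM_RECORDS,
    num_epochs=NUM_EPOCHS,
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=42,
)

loader = grain.DataLoader(
    data_source=my_data_source,  # placeholder random-access data source
    sampler=sampler,
    worker_count=2,
)

# Indices 0 .. NUM_RECORDS * NUM_EPOCHS - 1 are all valid, i.e. the sampler
# exposes the epochs as one continuous stream instead of stopping after the
# first NUM_RECORDS items.
print(sampler[NUM_RECORDS])  # first item of the second epoch, not an error
```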
Here’s how I understand the current behavior:
The IndexSampler creates a MapDataset over the record keys (num_records of them), which acts like an endless list of record keys stitched together epoch after epoch.
Calling __getitem__ with an index at or beyond num_records (but below num_epochs * num_records) effectively reduces the index modulo num_records, so each record is visited once per epoch.
Shuffling appears to happen independently within each slice of num_records, so every epoch gets its own permutation while random access across epochs still works (see the sketch below).
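My mental model of that mapping, in plain Python (this is just how I picture it, not Grain’s actual implementation, and the per-epoch seeding is my own guess):

```python
import random

def record_key_for(index: int, num_records: int, seed: int) -> int:
    """Toy model: each epoch is an independently shuffled slice of num_records."""
    epoch, within_epoch = divmod(index, num_records)  # modulus over the stream
    perm = list(range(num_records))
    random.Random(seed + epoch).shuffle(perm)         # fresh permutation per epoch
    return perm[within_epoch]

# Indices 0..num_records-1 form epoch 0, num_records..2*num_records-1 form
# epoch 1, and so on; each record key shows up exactly once per epoch.
```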
My Goal:
I want to stop at the end of the first epoch and then resume at the next epoch. My current implementation tracks the iterator state using get_state() and set_state(), but this feels somewhat clunky.
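For reference, the plain save/restore round trip itself is simple enough (continuing from the loader above; the state is treated as an opaque blob here):

```python
data_iter = iter(loader)

for _ in range(100):           # run part of the first epoch
    batch = next(data_iter)
    # ... training step ...

state = data_iter.get_state()  # snapshot of the iterator position

# Later (possibly in a new process): rebuild the loader and restore.
resumed_iter = iter(loader)
resumed_iter.set_state(state)
```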
What I’ve Tried:
Tracking the iterator state:
After each batch is processed, I use get_state() to check the last_seen_indices.
Checking for the epoch boundary:
I try to detect when the indices reach the end of the current epoch by comparing last_seen_indices against a boundary computed as (current_epoch + 1) * num_records.
State update upon crossing the boundary:
Once the boundary is crossed, I update the state by setting each worker’s last_seen_indices entry to just before the start of the next epoch and resetting last_worker_index so that the workers resume correctly.
Training loop logic:
After each batch, I check whether the epoch boundary has been crossed. If it has, I break out of the loop, save the state, and later resume from it (rough sketch below).
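Concretely, the loop looks roughly like this. epoch_boundary_crossed is a helper I wrote, not a Grain API, and I’m assuming the state returned by get_state() can be JSON-decoded into a dict with a per-worker last_seen_indices mapping (which is what I see in my version; adjust if yours hands back a dict directly):

```python
import json

def epoch_boundary_crossed(state: dict, current_epoch: int, num_records: int) -> bool:
    """Hypothetical helper: have all workers finished the current epoch?"""
    boundary = (current_epoch + 1) * num_records
    last_seen = state["last_seen_indices"]  # one entry per worker
    # Workers receive indices round-robin, so at the end of an epoch each
    # worker's last index falls within the final len(last_seen) positions.
    return all(idx is not None and idx >= boundary - len(last_seen)
               for idx in last_seen.values())

current_epoch = 0
data_iter = iter(loader)

for batch in data_iter:
    # ... training step ...
    raw_state = data_iter.get_state()
    state = json.loads(raw_state)           # JSON bytes -> dict in my version
    if epoch_boundary_crossed(state, current_epoch, NUM_RECORDS):
        # This is also where I rewrite last_seen_indices / last_worker_index so
        # that every worker resumes exactly at the start of the next epoch.
        break

# Save raw_state with the checkpoint; on the next run, restore and continue:
current_epoch += 1
data_iter = iter(loader)
data_iter.set_state(raw_state)
```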
However, this approach feels quite manual, and I’m wondering whether there is a cleaner or intended way to handle this scenario, especially regarding the shuffling behavior across epochs and stopping/resuming the dataloader.
Questions:
Is there a cleaner way to stop the dataloader at the end of each epoch and resume from the next?
Is my approach of tracking the iterator state with get_state() and set_state() appropriate for this use case, or is there a better approach?
Thank you for your assistance!