
Handling Per-Epoch Training with Grain Dataloader #572

Open danbnyn opened 1 month ago

danbnyn commented 1 month ago

Hi there,

I’m currently working with the Grain DataLoader and have hit an issue with stopping and resuming training at each epoch boundary when the sampler is initialized with num_epochs > 1. I’d like some guidance on how to handle this scenario properly, and on how shuffling works across epochs.

The Issue:

I’m using the IndexSampler with num_epochs > 1, and my expectation was to be able to stop the DataLoader at the end of each epoch and resume training at the next one. However, the sampler provides a total of num_samples * num_epochs indices, so all epochs are flattened into a single continuous stream with no epoch boundary.
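For reference, here’s roughly how I’m constructing the sampler and loader. The record count, epoch count, seed, and worker count are all placeholder values, and the list stands in for my real data source:

```python
import grain.python as grain

num_records = 1_000   # placeholder dataset size
num_epochs = 3        # placeholder epoch count
worker_count = 4      # placeholder worker count

data_source = list(range(num_records))  # toy in-memory random-access source

sampler = grain.IndexSampler(
    num_records=num_records,
    num_epochs=num_epochs,
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=42,
)

loader = grain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    worker_count=worker_count,
)
# Iterating over `loader` yields num_records * num_epochs examples as one
# continuous stream; nothing signals the boundary between epochs.
```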

Here’s how I understand the current behavior:

  • The sampler emits one stream of indices from 0 to num_samples * num_epochs - 1, rather than a separate stream per epoch.
  • A global index maps to epoch index // num_records and record index % num_records, and (as far as I can tell) the shuffle order is recomputed for each epoch from the seed.

My Goal:

I want to stop at the end of the first epoch and then resume at the next epoch. My current implementation tracks the iterator state using get_state() and set_state(), but this feels somewhat clunky.

What I’ve Tried:

  1. Tracking the iterator state: after each batch is processed, I call get_state() and inspect last_seen_indices.
  2. Checking for the epoch boundary: I detect that the indices have reached the end of the current epoch by comparing last_seen_indices against the boundary value (current_epoch + 1) * num_records.
  3. Updating the state upon crossing the boundary: once the boundary is crossed, I rewrite last_seen_indices so that each worker sits just before its first index of the next epoch, and I reset last_worker_index so the workers resume in the right order.
  4. Training loop logic: after each batch, I check the epoch boundary; if it has been crossed, I break out of the loop, save the state, and later resume from it (see the sketch after this list).
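
Here’s a condensed sketch of that loop, continuing from the setup snippet above. It assumes get_state() returns a dict whose last_seen_indices maps each worker to the last global index it produced, and the rewind offsets (boundary + worker - worker_count, last_worker_index = -1) just mirror what a fresh iterator’s state looks like in my version, so treat those details as assumptions rather than documented API:

```python
import copy

def train_step(batch):
    pass  # stand-in for the real training step

iterator = iter(loader)
current_epoch = 0
resume_state = None

for batch in iterator:
    train_step(batch)

    state = iterator.get_state()
    boundary = (current_epoch + 1) * num_records

    # Step 2: the boundary is crossed once every worker's last seen
    # index has moved into the next epoch.
    if all(idx >= boundary for idx in state["last_seen_indices"].values()):
        # Step 3: rewind each worker to just before its first index of
        # the next epoch and reset the round-robin pointer (assumed
        # layout; mirrors a freshly created iterator's state here).
        resume_state = copy.deepcopy(state)
        for worker in resume_state["last_seen_indices"]:
            resume_state["last_seen_indices"][worker] = (
                boundary + int(worker) - worker_count
            )
        resume_state["last_worker_index"] = -1
        break  # Step 4: stop training for this epoch; checkpoint resume_state

# Later, to resume at the start of the next epoch:
# iterator = iter(loader)
# iterator.set_state(resume_state)
# current_epoch += 1
```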

However, this approach feels manual and fragile, and I’m wondering whether there’s a more idiomatic or intended way to handle this scenario, particularly with respect to shuffling behavior across epochs and stopping/resuming the DataLoader.

Questions:

  1. Is there a cleaner way to stop the DataLoader at the end of each epoch and resume at the start of the next one?
  2. Is my approach of tracking the iterator state with get_state() and set_state() appropriate for this use case, or is there a better approach?

Thank you for your assistance!