huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Dataloader yields wrong sequence when resuming training #2823

Open lolalebreton opened 1 month ago

lolalebreton commented 1 month ago

System Info

- `Accelerate` version: 0.29.3
- Platform: Linux-5.15.0-101-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /(...)/.venv/bin/accelerate
- Python version: 3.10.11
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.2+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 377.53 GB
- GPU type: Quadro RTX 8000
- `Accelerate` default config:
        Not found

Reproduction

This concerns training with a dataloader that reshuffles at every epoch. For reproducibility, a resumed run should have the dataloader yield the same order as the epoch in which training was interrupted. However, calling `train_dataloader.set_epoch(epoch)` has no effect: the yielded sequence is the same whatever epoch value is passed.

Instead, if training was interrupted during epoch n, the resumed dataloader yields the sequence of epoch n+1.

Here is a minimal example of the output for `DataLoader(list(range(10)), shuffle=True, batch_size=4)`:

Without resuming:

Epoch: 0 [6, 7, 1, 4] [2, 0, 9, 8] [3, 5]

Epoch: 1 [2, 4, 7, 0] [8, 9, 5, 3] [6, 1]

Epoch: 2 [7, 4, 5, 1] [9, 3, 8, 2] [0, 6]

With resuming after two steps:

Epoch: 0 [6, 7, 1, 4] [2, 0, 9, 8] *interruption

*resuming [6, 1]

Epoch: 1 [7, 4, 5, 1] [9, 3, 8, 2] [0, 6]
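A possible explanation, offered as a hypothesis rather than a confirmed diagnosis: with `shuffle=True` and no explicit generator, PyTorch's `RandomSampler` draws a fresh shuffle seed from the global torch RNG each time a pass over the loader starts, and `accelerate`'s `save_state()` saves RNG states alongside the checkpoint. If the checkpoint captures the RNG after epoch 0's seed has already been drawn, restoring it makes the next pass use epoch 1's seed. The plain-PyTorch sketch below (no Accelerate involved) simulates this, with `torch.get_rng_state()`/`torch.set_rng_state()` standing in for the checkpoint. `make_loader`, `run_uninterrupted` and `run_with_resume` are hypothetical helpers. Under these assumptions the resumed run yields the tail of epoch 1 and then all of epoch 2, the same pattern as the output above.

```python
import itertools

import torch
from torch.utils.data import DataLoader

DATA = list(range(10))
BATCHES_BEFORE_INTERRUPT = 2  # matches "resuming after two steps" above


def make_loader():
    # No explicit generator: RandomSampler draws its shuffle seed from the
    # global torch RNG each time a new pass over the loader starts.
    return DataLoader(DATA, shuffle=True, batch_size=4)


def run_uninterrupted(num_epochs, seed=0):
    torch.manual_seed(seed)
    loader = make_loader()
    return [[batch.tolist() for batch in loader] for _ in range(num_epochs)]


def run_with_resume(seed=0):
    # First run: stop partway through epoch 0 and snapshot the global RNG,
    # standing in for the RNG state stored in a checkpoint.
    torch.manual_seed(seed)
    batches = iter(make_loader())
    for _ in range(BATCHES_BEFORE_INTERRUPT):
        next(batches)
    saved_rng = torch.get_rng_state()

    # Second run: reseed, then restore the snapshot as a checkpoint load would.
    torch.manual_seed(seed)
    torch.set_rng_state(saved_rng)
    loader = make_loader()
    # Skip the batches already seen to "finish" epoch 0, then run one more epoch.
    rest_of_epoch_0 = [
        b.tolist() for b in itertools.islice(loader, BATCHES_BEFORE_INTERRUPT, None)
    ]
    next_epoch = [b.tolist() for b in loader]
    return rest_of_epoch_0, next_epoch


full = run_uninterrupted(3)
rest_of_epoch_0, next_epoch = run_with_resume()
# Compare the resumed tail with uninterrupted epoch 1, and the next epoch with epoch 2.
print(rest_of_epoch_0 == full[1][2:])
print(next_epoch == full[2])
```

Nothing between the snapshot and the end of epoch 0 touches the global RNG (the sampler shuffles with its own local generator), so the restored state is exactly the state at which the uninterrupted run starts epoch 1.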

Code to reproduce

import os
import re
import hydra
from omegaconf import DictConfig
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from collections import defaultdict

class Metrics(defaultdict):

    def __init__(self):
        super().__init__(int)

    def state_dict(self):
        return dict(self)

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            self[k] = v

    def log(self, accelerator: Accelerator):
        # Build the metrics to log
        metrics_log = dict()
        metrics_log["train/epochs"] = self["train/epochs"]
        metrics_log["train/steps"] = self["train/steps"]

        # Log the metrics
        accelerator.log(metrics_log)

def write_to_file(path, accelerator, obj):
    if accelerator.is_main_process:
        with open(path, "a") as f:
            f.write(obj)
            f.write("\n")

@hydra.main(version_base=None, config_path="../conf", config_name="config")
def train(cfg: DictConfig):
    # Get the last checkpoint id
    checkpoint_dir = os.path.join(cfg.trainer.dir, "checkpoints")
    iteration = 0
    if cfg.trainer.resume and os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
        folders = os.listdir(checkpoint_dir)
        iteration = max(int(re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)[0]) for folder in folders) + 1

    # Accelerator object
    project_config = ProjectConfiguration(
        cfg.trainer.dir,
        automatic_checkpoint_naming=True,
        total_limit=50,
        iteration=iteration,
    )
    accelerator = Accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=1,
        project_config=project_config,
    )

    # File to log outputs
    path = "log.txt"

    # Set the seed
    set_seed(cfg.seed)

    # Local and global counters
    metrics = Metrics()
    accelerator.register_for_checkpointing(metrics)

    train_dataloader = DataLoader(list(range(10)), shuffle=True, batch_size=4)

    # Accelerate
    train_dataloader = accelerator.prepare(train_dataloader)

    # Resume from the latest checkpoint
    skipped_train_dataloader = None
    if cfg.trainer.resume and os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
        accelerator.load_state()
        if accelerator.is_main_process:
            write_to_file(path, accelerator, "\nResuming in epoch: " + str(metrics["train/epochs"]))
        train_dataloader.set_epoch(metrics["train/epochs"])
        skipped_train_dataloader = accelerator.skip_first_batches(train_dataloader, metrics["train/batches"] % len(train_dataloader))

    while cfg.trainer.max_steps > metrics["train/steps"]:
        # Use skipped_train_dataloader the first epoch after resuming
        dataloader = train_dataloader if skipped_train_dataloader is None else skipped_train_dataloader

        write_to_file(path, accelerator, "\nEpoch: " + str(metrics["train/epochs"]))
        for batch in dataloader:
            # Update number of batches
            metrics["train/batches"] += 1

            write_to_file(path, accelerator, "\nSteps: " + str(metrics["train/steps"]))
            write_to_file(path, accelerator, str(torch.flatten(batch).tolist()))

            metrics["train/steps"] += 1
            accelerator.save_state()

            if metrics["train/steps"] >= cfg.trainer.max_steps:
                break

        # Update the epoch counter
        metrics["train/epochs"] += 1

        # "Remove" the skipped dataloader once exhausted
        skipped_train_dataloader = None

    # Make sure any registered trackers finish correctly
    accelerator.end_training()

if __name__ == "__main__":
    train()

Expected behavior

The dataloader should yield identical sequences with or without resuming.
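One way to get this behavior, sketched in plain PyTorch rather than through any Accelerate API, is to make the shuffle order a pure function of (seed, epoch). PyTorch's `DistributedSampler` does this: it reshuffles with a generator seeded from `seed + epoch`, and `set_epoch()` is how it is meant to be advanced. Passing `num_replicas=1` and `rank=0` explicitly lets it run without initializing `torch.distributed`. `SEED`, `make_loader` and `epoch_batches` are hypothetical names, and whether Accelerate's prepared dataloader forwards `set_epoch` to such a sampler in your version is an assumption to verify.

```python
import itertools

import torch
from torch.utils.data import DataLoader, DistributedSampler

DATA = list(range(10))
SEED = 0  # hypothetical fixed seed


def make_loader():
    # The order depends only on SEED and the value passed to set_epoch,
    # not on the global RNG state, so nothing has to be restored on resume.
    sampler = DistributedSampler(DATA, num_replicas=1, rank=0, shuffle=True, seed=SEED)
    return DataLoader(DATA, sampler=sampler, batch_size=4), sampler


def epoch_batches(epoch, skip=0):
    # Build a fresh loader (as a resumed process would), select the epoch,
    # and skip the batches that were already consumed before the interruption.
    loader, sampler = make_loader()
    sampler.set_epoch(epoch)
    return [b.tolist() for b in itertools.islice(loader, skip, None)]


uninterrupted = [epoch_batches(epoch) for epoch in range(3)]

# Simulated resume: interrupted during epoch 0 after two batches.
resumed_epoch_0 = epoch_batches(0, skip=2)
resumed_epoch_1 = epoch_batches(1)
print(resumed_epoch_0 == uninterrupted[0][2:], resumed_epoch_1 == uninterrupted[1])
```

Because each epoch's order is recomputed from `(SEED, epoch)`, the resumed run only needs the epoch number and the count of batches already seen, both of which the reproduction script already tracks in `metrics`.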

muellerzr commented 1 month ago

Can you try updating your accelerate version to see whether this was fixed in a more recent release?

lolalebreton commented 1 month ago

Hi! Thank you for your answer. It gives the same results with accelerate 0.30.1

lolalebreton commented 2 weeks ago

Hello! I'm gently bumping this issue: have you had a chance to look into it?

BenjaminBossan commented 1 week ago

Sorry for the delay. Zach is currently out of office but I'm sure he'll look into it when he's back.