huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Wrong epoch when resuming from checkpoint #3242

Open · xiechun-tsukuba opened this issue 4 days ago

xiechun-tsukuba commented 4 days ago

System Info

- `Accelerate` version: 1.1.1
- Platform: Linux-3.10.0-1160.108.1.el7.x86_64-x86_64-with-glibc2.17
- `accelerate` bash location: /home/xiechun/micromamba/envs/PixArt/bin/accelerate
- Python version: 3.11.9
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.1 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 186.33 GB
- GPU type: Tesla V100-PCIE-32GB
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: fp16
        - use_cpu: False
        - debug: False
        - num_processes: 4
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: all
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Reproduction

This is highly related to #2823, which was closed as completed by the bot, but I don't think the problem has been solved. I encountered the same problem and would like to provide some more information.

As described here, a dataloader resumed from a checkpoint should know which epoch it should start with, but as reported in #2823, the sequence the dataloader yields after resuming is actually that of epoch n+1 if training was interrupted during epoch n.

I reproduced the bug using the code below. When num_processes=1, the behavior is the same as in #2823: I saved the checkpoint in the middle of epoch 2, but the dataloader starts from epoch 3 after resuming. However, when num_processes>1, the dataloader is not resumed at all; it starts over from epoch 1.

Here is the code I used to reproduce the bug.

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
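        # With interrupt=True, save a checkpoint in the middle of epoch index 1
        # ("Epoch 2" in the printouts below), right after the second batch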
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    accelerator = Accelerator()

    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    # default DataLoaderConfiguration (the workaround discussed below sets use_seedable_sampler=True)
    dataloader_config = DataLoaderConfiguration()
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for 1 epoch
    all_outputs = []
    for epoch in range(1):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

When num_processes=1, the output is:

--- Process 0 ---
Epoch 1:
[[8], [0], [9], [1], [4], [5], [6], [3], [2], [7]]

--- Process 0 ---
Epoch 2:
Random state saved
[[4], [6], [9], [3], [7], [5], [2], [0], [1], [8]]

--- Process 0 ---
Epoch 3:
[[6], [2], [1], [0], [3], [8], [5], [9], [4], [7]]

Random state loaded from debug_data_order_checkpoints

--- Process 0 ---
Epoch 1:
[[6], [2], [1], [0], [3], [8], [5], [9], [4], [7]]

You can see that the resumed epoch is identical to epoch 3.

When num_processes=2, the output is:

--- Process 1 ---
Epoch 1:
[[3], [5], [9], [1], [6]]

--- Process 1 ---
Epoch 2:
Random state saved
[[4], [9], [1], [7], [8]]

--- Process 1 ---
Epoch 3:
[[5], [2], [8], [7], [6]]

--- Process 0 ---
Epoch 1:
[[0], [7], [2], [8], [4]]

--- Process 0 ---
Epoch 2:
Random state saved
[[6], [2], [5], [0], [3]]

--- Process 0 ---
Epoch 3:
[[4], [9], [3], [1], [0]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints

--- Process 1 ---
Epoch 1:
[[3], [5], [9], [1], [6]]

--- Process 0 ---
Epoch 1:
[[0], [7], [2], [8], [4]]

In this case, the resumed epoch is identical to epoch 1.

The only way I found to make it work as expected is to set use_seedable_sampler=True in DataLoaderConfiguration and then call set_seed(42) followed by dataloader.set_epoch(1). However, I don't know if this is safe, because I didn't see set_epoch() documented anywhere; I only found it by chance in a PR. Moreover, I believe that resuming from the interrupted epoch is supposed to be the default behavior (as mentioned here), without extra configuration and regardless of the number of GPUs. Otherwise, all the examples in transformers and diffusers using accelerate.skip_first_batches are wrong, since none of them can resume the correct epoch.
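For reference, here is roughly what that workaround looks like on top of the script above; the seed, the epoch index passed to set_epoch(), and the number of skipped batches are just the values from my repro, and I am not sure this is the intended way to combine these APIs:

from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration

# use_seedable_sampler makes the shuffling order depend on (seed, epoch)
# instead of the global torch RNG state
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

set_seed(42)

dataset = SimpleDataset()  # same toy dataset as in the script above
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=1, shuffle=True))

accelerator.load_state("debug_data_order_checkpoints")

# manually tell the prepared dataloader which epoch was interrupted
# (epoch index 1 corresponds to "Epoch 2" in the printouts above)
dataloader.set_epoch(1)

# skip the two batches that were already consumed in that epoch
resumed_dataloader = accelerator.skip_first_batches(dataloader, 2)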

Expected behavior

The dataloader should start from the correct epoch after resuming, both when num_processes = 1 and when num_processes > 1.
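For illustration, a simplified version of the resume loop used by those example scripts (the variable names here are mine, not taken from any particular script) only gives the right data if the dataloader replays the interrupted epoch's order after load_state:

def resume_and_train(accelerator, dataloader, checkpoint_dir, first_epoch, resume_step, num_epochs):
    """Simplified sketch of the usual resume loop; names are illustrative only."""
    # restores model/optimizer/scheduler/RNG state saved by accelerator.save_state
    accelerator.load_state(checkpoint_dir)

    for epoch in range(first_epoch, num_epochs):
        if epoch == first_epoch:
            # Skip the batches already consumed in the interrupted epoch.
            # This only yields the right samples if the dataloader replays the
            # same shuffled order it had when the checkpoint was saved.
            active_dataloader = accelerator.skip_first_batches(dataloader, resume_step)
        else:
            active_dataloader = dataloader
        for step, batch in enumerate(active_dataloader):
            pass  # training step goes here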