huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0
7.32k stars, 872 forks

Why is there a double fetch in the first batch when using accelerate? #2884

Open qsunyuan opened 1 week ago

qsunyuan commented 1 week ago

Default (plain PyTorch DataLoader):

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        print(f"Fetching index {index}")
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Sample data
data = list(range(10))
dataset = MyDataset(data)

# DataLoader with batch_size=1 and num_workers=0
dataloader = DataLoader(dataset, batch_size=1, num_workers=0)

for batch in dataloader:
    print(batch)

Using accelerate:

from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        print(f"Fetching index {index}")
        return self.data[index]

    def __len__(self):
        return len(self.data)

data = list(range(10))
dataset = MyDataset(data)

accelerator = Accelerator()

dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
dataloader = accelerator.prepare(dataloader)

for epoch in range(1):  
    for batch in dataloader:
        print(batch)

The output with accelerate is as follows. Two fetches (index 0 and index 1) are performed before the first batch is printed.

Fetching index 0
Fetching index 1
tensor([0], device='cuda:0')
Fetching index 2
tensor([1], device='cuda:0')
Fetching index 3
tensor([2], device='cuda:0')
Fetching index 4
tensor([3], device='cuda:0')
Fetching index 5
tensor([4], device='cuda:0')
Fetching index 6
tensor([5], device='cuda:0')
Fetching index 7
tensor([6], device='cuda:0')
Fetching index 8
tensor([7], device='cuda:0')
Fetching index 9
tensor([8], device='cuda:0')
tensor([9], device='cuda:0')

SunMarc commented 1 week ago

Hi @qsunyuan, thanks for reporting! This is expected: we iterate one batch ahead so that we can tell when we have reached the end. See here. This can potentially be solved, cc @muellerzr. Do you have a specific use case where you modify the getter of the dataset?
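
To illustrate the "one batch ahead" behavior described above, here is a minimal pure-Python sketch of a lookahead iterator. It is not accelerate's actual implementation; the names `iterate_one_ahead`, `LoggingSource`, and `run_demo` are hypothetical and exist only for this example. It shows why a second fetch happens before the first item is yielded, and why the last item is yielded without a further fetch.

```python
def iterate_one_ahead(iterable):
    """Yield (item, is_last) pairs while always holding one item in reserve.

    Hypothetical helper for illustration; not part of accelerate.
    """
    iterator = iter(iterable)
    try:
        current = next(iterator)  # first fetch
    except StopIteration:
        return  # empty input: nothing to yield
    while True:
        try:
            upcoming = next(iterator)  # fetch one item ahead
        except StopIteration:
            # Nothing came after `current`, so it is the final item.
            yield current, True
            return
        yield current, False
        current = upcoming


class LoggingSource:
    """Iterable that records every fetch into a shared event list."""

    def __init__(self, n, events):
        self.n = n
        self.events = events

    def __iter__(self):
        for i in range(self.n):
            self.events.append(f"Fetching index {i}")
            yield i


def run_demo(n=3):
    """Return the interleaved sequence of fetch and yield events."""
    events = []
    for item, is_last in iterate_one_ahead(LoggingSource(n, events)):
        events.append(f"batch {item} (last={is_last})")
    return events


if __name__ == "__main__":
    for line in run_demo(3):
        print(line)
```

With three items, the event order mirrors the pattern in the reported output: two fetches up front, then one fetch per yielded item, and a final item that is yielded with no fetch after it. The benefit of this design is that the consumer knows it is on the last batch while handling it; the cost is that any side effect in the getter runs one step earlier than it would with a plain loop.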