🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, with automatic mixed precision (including fp8) and easy-to-configure FSDP and DeepSpeed support
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Toy map-style dataset that logs every element access.

    The print in ``__getitem__`` makes the DataLoader's fetch order
    observable, which is the whole point of this reproduction script.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # Log the access so the fetch order can be compared across runs.
        print(f"Fetching index {index}")
        return self.data[index]

    def __len__(self):
        return len(self.data)


# Sample data
data = list(range(10))
dataset = MyDataset(data)

# DataLoader with batch_size=1 and num_workers=0
dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
for batch in dataloader:
    print(batch)
Using accelerate
from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Toy map-style dataset that logs every element access."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # Log the access so the fetch order can be compared across runs.
        print(f"Fetching index {index}")
        return self.data[index]

    def __len__(self):
        return len(self.data)


data = list(range(10))
dataset = MyDataset(data)

accelerator = Accelerator()
dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
# prepare() wraps the loader for the current device / distributed setup;
# the wrapped loader is what changes the observed fetch pattern below.
dataloader = accelerator.prepare(dataloader)

for epoch in range(1):
    for batch in dataloader:
        print(batch)
The output is as follows: note that two fetches are performed before the first batch is yielded, because the prepared dataloader reads one batch ahead.
Fetching index 0
Fetching index 1
tensor([0], device='cuda:0')
Fetching index 2
tensor([1], device='cuda:0')
Fetching index 3
tensor([2], device='cuda:0')
Fetching index 4
tensor([3], device='cuda:0')
Fetching index 5
tensor([4], device='cuda:0')
Fetching index 6
tensor([5], device='cuda:0')
Fetching index 7
tensor([6], device='cuda:0')
Fetching index 8
tensor([7], device='cuda:0')
Fetching index 9
tensor([8], device='cuda:0')
tensor([9], device='cuda:0')
Hi @qsunyuan, thanks for reporting! This is normal since we iterate one batch ahead to check when we are at the end. See here. This can potentially be solved, cc @muellerzr. Do you have a specific use case where you modify the getter of the dataset?
Default:
Using accelerate
The output is as follows, with two fetches performed in the first batch.