Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

incorrect dataloader length when `drop_last=False` #402

Open grez72 opened 1 week ago

grez72 commented 1 week ago

🐛 Bug

When drop_last=False, len(StreamingDataLoader) returns an incorrect length if batch_size does not evenly divide len(dataset). It appears to return ceil(len(dataset) / batch_size), but the actual number of batches is larger and depends on num_workers (apparently each worker yields its own final batch that is smaller than batch_size). One consequence is that the number of full batches (where the actual batch size equals dataloader.batch_size) is less than len(dataset) // batch_size.
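The mismatch can be reproduced with plain arithmetic. The sketch below is not litdata's implementation: it assumes the samples are split evenly across workers with the remainder going to the last worker (an assumption inferred from the partial batch sizes reported further down), and that each worker batches its own share independently. The function names are hypothetical.

```python
import math


def reported_len(num_samples: int, batch_size: int) -> int:
    """Length that len(dataloader) appears to report."""
    return math.ceil(num_samples / batch_size)


def per_worker_batch_sizes(num_samples: int, batch_size: int, num_workers: int):
    """Batch sizes each worker would yield if it batched its own share alone."""
    base, extra = divmod(num_samples, num_workers)
    shares = [base] * num_workers
    shares[-1] += extra  # assumption: the last worker takes the remainder
    result = []
    for share in shares:
        full, rest = divmod(share, batch_size)
        # `full` complete batches, plus one partial batch if anything is left
        result.append([batch_size] * full + ([rest] if rest else []))
    return result


batches = per_worker_batch_sizes(50_000, 256, 12)
actual_len = sum(len(worker_batches) for worker_batches in batches)
partial_sizes = [b[-1] for b in batches if b[-1] != 256]

print(reported_len(50_000, 256))  # 196
print(actual_len)                 # 204
print(partial_sizes)              # eleven batches of 70 and one of 78
```

Under these assumptions the numbers agree with the output in the reproduction: 196 reported batches versus 204 yielded, with 12 partial batches.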

I noticed this because I use fastprogress.progress_bar instead of tqdm. progress_bar appears to use len(dataloader) as the total number of items to iterate over, and consequently drops the extra partial batches. I expected to iterate over the full ImageNet validation set (50000 samples), but only iterated over 49432 samples even though I set drop_last=False.
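The 49432 figure follows directly from stopping after len(dataloader) batches. The sketch below does not use fastprogress. It only models a consumer that trusts the reported length, using itertools.islice, and it assumes that all full batches are emitted before the per-worker partial batches (which is consistent with the output shown in the reproduction).

```python
from itertools import islice

BATCH_SIZE = 256
NUM_WORKERS = 12
FULL_BATCHES_PER_WORKER = 16           # 192 full batches across 12 workers
PARTIAL_SIZES = [70] * 11 + [78]       # partial batch sizes from the report
REPORTED_LEN = 196                     # value returned by len(dataloader)


def emitted_batch_sizes():
    """Yield batch sizes in the assumed order: full batches, then partials."""
    for _ in range(FULL_BATCHES_PER_WORKER * NUM_WORKERS):
        yield BATCH_SIZE
    yield from PARTIAL_SIZES


# A consumer that stops after the reported length sees only the first 196 batches.
seen = list(islice(emitted_batch_sizes(), REPORTED_LEN))
everything = list(emitted_batch_sizes())

print(len(seen), sum(seen))              # 196 49432
print(len(everything), sum(everything))  # 204 50000
```

The 568 missing samples are the eight partial batches that come after the 196th batch.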

To Reproduce

Steps to reproduce the behavior...

Code sample

## generate a fake dataset for testing

```python
import os, io
import numpy as np
from PIL import Image
import litdata as ld

def random_images_jpeg_encode(index):
    fake_images = Image.fromarray(np.random.randint(0, 256, (224, np.random.choice([224,320,384]), 3), dtype=np.uint8))
    fake_labels = np.random.randint(10)
    image_bytes = io.BytesIO()
    fake_images.save(image_bytes, format="JPEG", quality=100, optimize=True)
    image_bytes.seek(0)

    # You can use any key:value pairs. Note that their types must not change between samples, and Python lists must
    # always contain the same number of elements with the same types.
    data = {"index": index, "image": image_bytes.read(), "label": fake_labels}

    return data

ld.optimize(
    fn=random_images_jpeg_encode,  # the function applied to each input
    inputs=list(range(50000)),     # the inputs to the function (here it's a list of numbers)
    output_dir="fast_data",        # optimized data is stored here
    num_workers=4,                 # The number of workers on the same machine
    chunk_bytes="64MB"             # size of each chunk
)
```

## helpers for testing iteration over fake dataset

```python
import os
import torch
from tqdm import tqdm
from fastprogress import progress_bar
from litdata import StreamingDataset, StreamingDataLoader
from litdata.streaming.serializers import JPEGSerializer
import torchvision.transforms.v2 as T2
from pdb import set_trace

serializer = JPEGSerializer()

class ImageNetStreamingDataset(StreamingDataset):

    def __init__(self, *args, **kwargs):
        self.transform = T2.Compose([
            lambda img_bytes: serializer.deserialize(img_bytes),
            T2.RandomResizedCrop(224, antialias=True),
            T2.RandomHorizontalFlip(p=.5),
            T2.ToImage(),
            T2.ToDtype(torch.float16, scale=True),
        ])
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        # Note: If torchvision is installed, we return a tensor image instead of a pil image as it is much faster.
        sample = super().__getitem__(idx)  # <- Whatever you returned from the DatasetOptimizer prepare_item method.
        sample['image'] = self.transform(sample['image'])
        return sample

def get_dataloader(input_dir, num_workers, batch_size, drop_last):
    dataset = ImageNetStreamingDataset(input_dir, shuffle=False, drop_last=drop_last)
    print(f"Length of dataset: {len(dataset)}")
    dataloader = StreamingDataLoader(dataset, num_workers=num_workers, batch_size=batch_size,
                                     profile_batches=False, shuffle=False, drop_last=drop_last)
    print(f"Length of dataloader: {len(dataloader)}")

    return dataloader

def iterate_dataloader(dataloader, pbar):
    # iterate over dataloader
    image_count = 0
    batch_count = 0
    full_batch_count = 0
    partial_batch_sizes = []
    for batch_num, sample in enumerate(pbar(dataloader)):
        batch_count += 1
        image_count += sample['image'].shape[0]
        bs = sample['image'].shape[0]
        if bs != dataloader.batch_size:
            partial_batch_sizes.append(bs)
        else:
            full_batch_count += 1

    print(f"batch_size: {dataloader.batch_size}")
    print(f"num_workers: {dataloader.num_workers}")
    if len(dataloader) != batch_count:
        print(f"\u274C len(dataloader) = {len(dataloader)}, actual num_batches = {batch_count}")
    else:
        print(f"\u2705 len(dataloader) = {len(dataloader)}, actual num_batches = {batch_count}")

    if image_count != len(dataloader.dataset):
        print(f"\u274C Actual number of images: {image_count}")
    else:
        print(f"\u2705 Actual number of images: {image_count}")
    print(f"Number of full batches (img_count == {dataloader.batch_size}): {full_batch_count}")
    print(f"Number partial batches (img_count < {dataloader.batch_size}): {len(partial_batch_sizes)}")
    print(f"Sizes of partial batches: {partial_batch_sizes}")
```

## test with tqdm

You'll see that len(dataloader) does not match the actual number of batches, but tqdm still iterates over the full dataset (ending with a bunch of partial batches, one per worker).

```python
dataloader = get_dataloader(input_dir='fast_data', num_workers = 12, batch_size = 256, drop_last = False)
iterate_dataloader(dataloader, tqdm)
```

Length of dataset: 50000
Length of dataloader: 196
batch_size: 256
num_workers: 12
❌ len(dataloader) = 196, actual num_batches = 204
✅ Actual number of images: 50000
Number of full batches (img_count == 256): 192
Number partial batches (img_count < 256): 12
Sizes of partial batches: [70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 78]

## test with fastprogress.progress_bar

fastprogress.progress_bar stops early (after len(dataloader) batches), dropping most of the partial batches.

```python
dataloader = get_dataloader(input_dir='fast_data',num_workers = 12, batch_size = 256, drop_last = False)
iterate_dataloader(dataloader, progress_bar)
```

Length of dataset: 50000
Length of dataloader: 196
batch_size: 256
num_workers: 12
✅ len(dataloader) = 196, actual num_batches = 196
❌ Actual number of images: 49432
Number of full batches (img_count == 256): 192
Number partial batches (img_count < 256): 4
Sizes of partial batches: [70, 70, 70, 70]

Expected behavior

I would expect len(dataloader) to return the actual number of batches that will be yielded when iterating over the dataloader.

I would also have expected there to be only one "partial batch" that's smaller than batch_size (similar to the behavior of the PyTorch DataLoader, torch.utils.data.DataLoader). So for the examples above, I would expect 195 batches of size 256 and a single partial batch of size 80 (195*256+80 = 50,000).
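The expected layout can be written down as a small sketch. It models batching over a single global stream of samples, and the helper name single_stream_batch_sizes is hypothetical (it is not part of litdata or PyTorch).

```python
def single_stream_batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False):
    """Batch sizes when all samples are batched as one global stream."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # at most one partial batch, at the very end
    return sizes


sizes = single_stream_batch_sizes(50_000, 256)
print(len(sizes), sizes[-1])  # 196 80
```

With this layout the number of batches equals ceil(50000 / 256) = 196, so the value currently returned by len(dataloader) would be correct.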

Additional context

latest litdata

github-actions[bot] commented 1 week ago

Hi! Thanks for your contribution! Great first issue!