Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

The tested speed is not as fast as expected. #60

Open tikboaHIT opened 8 months ago

tikboaHIT commented 8 months ago

πŸ› Bug

The tested speed is not as fast as expected.

Code sample

import os
import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms import Compose, Lambda
from litdata import StreamingDataset, StreamingDataLoader

from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo

input_dir = 's3://extract_frames/'
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

class ImagenetStreamingDataset(StreamingDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transform = Compose(
            [
                Lambda(lambda x: x / 255.0),
                NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
                # ShortSideScale(size=224),
                CenterCropVideo(224),
            ]
        )

    def __getitem__(self, index):
        data = super().__getitem__(index)
        video_data = []
        for i in range(8):
            frame = np.array(data["image"][i])
            video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
        video_data = torch.stack(video_data, dim=1)
        video_data = self.transform(video_data)
        return video_data

dataset = ImagenetStreamingDataset(input_dir, shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=8)
for batch in tqdm(dataloader, total=len(dataloader)):
    pass

Expected behavior

There are approximately 200,000 data points, each consisting of 8 extracted frames. Based on the tested speed, it should be very fast, but in reality it is not.

[Screenshot 2024-03-07 at 20:42:20]

The tested speed is approximately as follows:

[Screenshot 2024-03-07 at 20:48:30]

Environment

github-actions[bot] commented 8 months ago

Hi! Thanks for your contribution! Great first issue!

tchaton commented 8 months ago

Hey @tikboaHIT,

The benchmarks are fully reproducible for ImageNet, so you can verify the numbers yourself.

For your custom use case, there are many possible optimizations, especially around your transforms and the round trip through NumPy. Would you be open to creating a reproducible Studio on https://lightning.ai/?

In the meantime, you can pass profile_batches=10 to the StreamingDataLoader to check where the time is spent. When you can, please share the trace with me so I can help you optimize it.
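
For reference, a minimal sketch of wiring this into the loader from the original report (assuming profile_batches behaves as described here and the resulting trace, e.g. result.json, is written to the working directory):

from litdata import StreamingDataset, StreamingDataLoader

# Profile the first 10 batches; the exact location of the trace file
# (result.json) is assumed from this discussion, not from the documentation.
dataset = StreamingDataset('s3://extract_frames/', shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=8, profile_batches=10)
for batch in dataloader:
    pass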

tikboaHIT commented 8 months ago

Thanks @tchaton. Where should I find the generated result.json? After running StreamingDataLoader(dataset, batch_size=64, num_workers=16, profile_batches=5), I couldn't find the corresponding file.

tchaton commented 8 months ago

It should appear where you run the command. Maybe reduce the batch size and number of workers.

Could you provide a synthetic example for me to debug too? This helps tremendously when optimizing these things. Here is another user's synthetic script as a reference: https://github.com/Lightning-AI/litdata/issues/62#issuecomment-1984466632.

tikboaHIT commented 8 months ago

> It should appear where you run the command. Maybe reduce the batch size and number of workers.
>
> Could you provide a synthetic example for me to debug too? This helps tremendously when optimizing these things. Here is another user's synthetic script as a reference: #62 (comment).

Sure:

import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset

def generate_images(video_path):
    data = {
        "name": video_path,
        "image": torch.rand((3, 8, 320, 568)),
    }
    return data

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/root/data/example_data/chunk_cache",
    num_workers=1,
    chunk_bytes="256MB",
)

input_dir = '/root/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

input_dir = 's3://pzhao/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

The speed is as follows when I load from local disk.

[Screenshot 2024-03-08 at 19:48:10]

The speed is as follows when I load from S3.

[Screenshot 2024-03-08 at 19:48:29]

tchaton commented 8 months ago

Hey @tikboaHIT. Thanks. Are you streaming from S3 to your local machine? If so, you might also be bottlenecked by your own internet connection.

Additionally, this is very unoptimized: you are storing the full video as a raw tensor. This is usually 10-100x larger than JPEG encoding, and 1000x larger than the AV1 format.

def generate_images(video_path):
    data = {
        "name": video_path,
        "image": torch.rand((3, 8, 320, 568)),
    }
    return data
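
For a rough sense of scale, here is a back-of-the-envelope sketch (the ~100 KB-per-JPEG-frame figure is an assumption; real sizes depend on content and quality):

# Approximate storage per sample for the snippet above.
raw_bytes = 3 * 8 * 320 * 568 * 4   # ~17.4 MB for the raw float32 clip
jpeg_bytes = 8 * 100 * 1024         # ~0.8 MB for 8 JPEG-encoded frames (assumed size)
print(raw_bytes / jpeg_bytes)       # roughly 21x larger, before any resizing
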
tchaton commented 8 months ago

Hey @tikboaHIT, I ran the exact same code on a Lightning AI A10G.

import os
import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset, StreamingDataLoader

input_dir = '/teamspace/datasets/videos3'
dataset = StreamingDataset(input_dir, shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=1, num_workers=1)  # or num_workers=os.cpu_count()
for data in tqdm(dataset):
    pass

It took 21 seconds for me, so it is most likely your internet connection. Usually, a streaming dataset is much faster when you stream within the cloud provider where the data is stored.

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:21<00:00,  4.59it/s]
tchaton commented 8 months ago

A more efficient approach would be to encode the images as JPEG, as follows:

import os
import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset
from PIL import Image
from io import BytesIO

def generate_images(video_path):
    images = []
    for _ in range(8):
        random_image = torch.randint(0, 255, (320, 568, 3), dtype=torch.uint8).numpy()
        buff = BytesIO()
        Image.fromarray(random_image).save(buff, quality=90, format='JPEG') # You can implement a better resizing logic
        buff.seek(0)
        img = buff.read()
        images.append(Image.open(BytesIO(img)))
    return {
        "name": video_path,
        "image": images,
    }

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/teamspace/datasets/videos5",
    num_workers=1,  # or os.cpu_count()
    chunk_bytes="64MB",
)

When streaming it from the cloud, it takes 1 second now.

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:01<00:00, 50.08it/s]

Additionally, I recommend using torchvision.transforms.v2, which is roughly 40% faster at resizing images, etc.
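
As an illustration, here is a possible v2 version of the transform pipeline from the original report (a sketch only; it assumes torchvision >= 0.16 and that the frames are stacked along dim 0 so the clip has shape (T, C, H, W)):

import torch
from torchvision.transforms import v2

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

# v2 transforms broadcast over leading dimensions, so a (T, C, H, W) uint8 clip
# goes through the whole pipeline without a manual per-frame loop.
transform = v2.Compose(
    [
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        v2.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
        v2.CenterCrop(224),
    ]
)

clip = torch.randint(0, 255, (8, 3, 320, 568), dtype=torch.uint8)  # T, C, H, W
out = transform(clip)  # shape (8, 3, 224, 224)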

Alternatively, we support videos through torchvision's video support: https://pytorch.org/audio/stable/build.ffmpeg.html. If you convert your clips to the AV1 format, they should get very small. You should be able to stream them easily and deserialize them faster. Worth exploring.

tchaton commented 8 months ago

@tikboaHIT I also fixed chunk_bytes not being handled correctly by the optimize operator.

tikboaHIT commented 8 months ago

@tchaton Thank you for your suggestions and the quick bug fix. Regarding the logic for saving frames, I can provide more context. I mainly extract and save frames with the following logic, storing them as tensors and PIL.Image objects, and I've found that the loading speed is very slow.

cmd = f"./preprocess/data_preprocess/get_frames.sh {video_name} 8 {frame_save_dir}"
os.popen(cmd).read()

imgs = {}
for index, img_path in enumerate(glob(f"{frame_save_dir}/*.jpg")):
    img = Image.fromarray(np.array(Image.open(img_path)))
    imgs[index] = img

data = {
    "name":  video_name,
    "image": imgs,
}

This should be similar to the solution you mentioned; both involve storing images in JPEG format.

You mentioned:

for _ in range(8):
    random_image = torch.randint(0, 255, (320, 568, 3), dtype=torch.uint8).numpy()
    buff = BytesIO()
    Image.fromarray(random_image).save(buff, quality=90, format='JPEG') # You can implement a better resizing logic
    buff.seek(0)
    img = buff.read()
    images.append(Image.open(BytesIO(img)))
tchaton commented 8 months ago

Hey @tikboaHIT. It isn't fully equivalent. In my case, I add compression by explicitly re-encoding the frames as JPEG with quality 90. You can see the difference by checking how many chunk files are created: in my case I had only one, while with your approach I had dozens.

Would you mind trying the code above? If you don't see any difference, it means you are bottlenecked by your internet connection. I would recommend trying out Lightning AI to get full speed.
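
For example, here is one way to adapt the snippet above to frames already extracted to disk (frame_save_dir comes from your earlier snippet; whether litdata stores the bytes exactly as encoded depends on its serializer, so treat this as a sketch):

from glob import glob
from io import BytesIO
from PIL import Image

def load_frames(frame_save_dir):
    images = []
    for img_path in sorted(glob(f"{frame_save_dir}/*.jpg")):
        buff = BytesIO()
        # Re-encode explicitly as JPEG (quality 90) instead of wrapping the
        # decoded raw pixels, so the stored sample stays compressed.
        Image.open(img_path).convert("RGB").save(buff, quality=90, format="JPEG")
        buff.seek(0)
        images.append(Image.open(buff))
    return images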

tikboaHIT commented 8 months ago

@tchaton Indeed, using JPEG compression can significantly improve reading speed, but overall, the process is not particularly smooth due to intermittent stuttering. Is this primarily limited by my AWS network?

The recording is below:

https://github.com/Lightning-AI/litdata/assets/39000220/408f78b8-7255-4183-8447-9d4bd4be1abe

tchaton commented 8 months ago

Hey @tikboaHIT. Yes, I suspect this is the bottleneck, and you would see the same no matter which library you use.

But there might be some extra optimizations to do for low-bandwidth machines. Would you be free for a pair-debugging session this week?

Best, T.C

tikboaHIT commented 8 months ago

Hey @tchaton, that sounds great. Let's continue our conversation on Discord and find a time that works for both of us to schedule a meeting.

tikboaHIT commented 8 months ago

Hey @tchaton, when I use litdata to load data, I suddenly encounter this issue when training reaches around 80% of the first epoch.

[Screenshot 2024-03-12 at 20:58:08]

tchaton commented 8 months ago

Hey @tikboaHIT, are you using the latest version of litdata? I think I resolved this bug on main. Otherwise, would you mind sharing a reproducing script?

tikboaHIT commented 8 months ago

@tchaton Yes, I am using version 2.2. Currently, the issue mainly occurs within a subset of the data. For context, I have divided several thousand videos into 10 parts, with each part forming its own chunks according to the frame-extraction logic above. There were no errors in the first part of the data, but this problem arose in the second part. For privacy reasons, the reproducing script involves some data that is not convenient to share.

Is there a way for me to debug on my own to find out the specific cause?

Additionally, there is another minor issue when I combine litdata with PyTorch Lightning. When the data volume is small [during the debugging process], training proceeds completely normally. However, when the data volume reaches around 500k and I follow a train -> validation -> train loop, training in the second epoch blocks indefinitely, and this error occurs: "watchdog caught collective operation timeout: WorkNCCL(SeqNum=27131, OpType=BROADCAST, NumelIn=274, NumelOut=274, Timeout(ms)=1800000) ran for 1800324 milliseconds before timing out." This issue does not arise if I skip the validation in between.

The partial code snippet is shown below.

[Screenshot 2024-03-13 at 23:20:03]

tchaton commented 8 months ago

Hey @tikboaHIT,

> For privacy reasons, the reproducing script involves some data that is not convenient to share.

Do you think you could try to reproduce the bug with synthetically generated data or even an open-source dataset, so I can debug it on my end?

> train -> validation -> train

Are you using DDP? Normally, each rank should get the same quantity of data, but it is possible there is a bug somewhere. If the lengths were different, it would hang.
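
To illustrate why mismatched per-rank lengths end in the watchdog timeout you saw (a generic DDP sketch, not litdata-specific; the all_reduce stands in for DDP's gradient sync and assumes an already-initialized process group):

import torch
import torch.distributed as dist

def run_epoch(num_batches: int) -> None:
    # Every rank must reach each collective the same number of times. If one
    # rank sees 101 batches and another only 100, the extra call below has no
    # partner and blocks until the NCCL watchdog timeout fires.
    for _ in range(num_batches):
        loss = torch.ones(1)
        dist.all_reduce(loss)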

tikboaHIT commented 8 months ago

Hey @tchaton, I've pinpointed the issue to a particular piece of data within a chunk. A normal read is shown below.

[Screenshot 2024-03-18 at 21:58:24]

The abnormal read is as follows:

[Screenshot 2024-03-18 at 21:57:37]

I'm not sure if this can be of help with your debugging.

tchaton commented 7 months ago

Hey @tikboaHIT. This means the chunk wasn't fully copied over when opened.