iejMac / video2dataset

Easily create large video dataset from video urls
MIT License

Example of how to handle padded video frames in a downstream network? #315

Status: Open. rohit901 opened this issue 4 months ago.

rohit901 commented 4 months ago

Hello,

I'm using the WebVid-10M dataset and passing the following decoder args:

decoder_kwargs = {
    "n_frames": 16,  # get 16 frames from each video
    "fps": 8,  # downsample to 8 FPS
    "num_threads": 12,  # use 12 threads to decode the video
}

Since I'm requesting 16 frames, I often get the error "ValueError: video clip not long enough for decoding", because some videos in the dataset are too short to supply 16 frames at 8 fps.
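As a rough illustration of why short clips fail: 16 frames at 8 fps span about 2 seconds, so a shorter video cannot yield a full clip without padding. The sketch below estimates how many source frames a clip needs. It assumes (this is an assumption, not confirmed video2dataset behaviour) that the decoder takes every `stride`-th source frame with stride = round(source_fps / target_fps); the helper names are hypothetical. A stride of 4, as seen in the later traceback, would match, for example, a ~30 fps source resampled to 8 fps.

```python
def min_source_frames(n_frames: int, source_fps: float, target_fps: float) -> int:
    """Estimate how many source frames are needed to sample `n_frames` at `target_fps`.

    Hypothetical model: every `stride`-th source frame is taken, with
    stride = round(source_fps / target_fps), never smaller than 1.
    """
    stride = max(1, round(source_fps / target_fps))
    # The first sampled frame is index 0 and the last is (n_frames - 1) * stride.
    return (n_frames - 1) * stride + 1


def is_long_enough(num_source_frames: int, n_frames: int, source_fps: float, target_fps: float) -> bool:
    """True if a video with `num_source_frames` frames can supply a full clip without padding."""
    return num_source_frames >= min_source_frames(n_frames, source_fps, target_fps)


if __name__ == "__main__":
    # 16 frames at 8 fps from a 30 fps source: stride 4, so 61 source frames (about 2 s).
    print(min_source_frames(16, 30, 8))  # 61
    print(is_long_enough(45, 16, 30, 8))  # False: a 1.5 s video is too short
```

Under this model, any video with fewer frames than `min_source_frames(...)` would need padding (or filtering) to produce a 16-frame clip.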

So if I pass pad_frames = True in the decoder_kwargs above, I believe shorter clips will be padded. Could you share an example snippet showing the best way to use such a padded batch in a neural network? For example, do we get access to the index where padding starts, or to a pad token? Do we need some kind of masking in the forward pass to ignore the padded frames, or can we leave the padded data as it is? I was planning to process the data with the UNet3D network from diffusers.
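One common way to keep padded frames from influencing training is to carry a per-frame validity mask alongside the batch and exclude padded frames from the loss. The sketch below assumes a boolean mask of shape (batch, frames) that is True for real frames; how to obtain that mask from video2dataset is not shown here, and `masked_mse` is a hypothetical helper. It uses NumPy to stay self-contained; the same arithmetic carries over to PyTorch tensors. Note that masking the loss alone does not stop temporal layers (for example in a 3D UNet) from mixing information from padded frames into real ones during the forward pass.

```python
import numpy as np


def masked_mse(pred: np.ndarray, target: np.ndarray, valid: np.ndarray) -> float:
    """Mean squared error over real (non-padded) frames only.

    pred, target: arrays of shape (B, T, ...), e.g. (B, T, H, W, C).
    valid: boolean array of shape (B, T); True marks a real frame, False a padded one.
    """
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    b, t = valid.shape
    sq_err = (pred.astype(np.float64) - target.astype(np.float64)) ** 2
    # Average over all per-frame dimensions, leaving one error value per (sample, frame).
    per_frame = sq_err.reshape(b, t, -1).mean(axis=-1)
    weights = valid.astype(np.float64)
    # Sum errors of real frames and divide by the number of real frames.
    return float((per_frame * weights).sum() / max(weights.sum(), 1.0))


if __name__ == "__main__":
    pred = np.zeros((2, 4, 2, 2, 3))
    target = np.ones((2, 4, 2, 2, 3))
    target[:, 3] = 100.0  # frame 3 is padding; its large error should be ignored
    valid = np.array([[True, True, True, False]] * 2)
    print(masked_mse(pred, target, valid))  # 1.0
```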

rohit901 commented 4 months ago

Also, I'm still getting errors with the current code even after passing pad_frames = True. Could you please help?

Error logs:

IndexError: Caught IndexError in DataLoader worker process 1.
Original Traceback (most recent call last):
....
  File "/home/rohit.bharadwaj/.conda/envs/LCM/lib/python3.11/site-packages/webdataset/handlers.py", line 24, in reraise_exception
    raise exn
  File "/home/rohit.bharadwaj/packages/vid2dataset/video2dataset/dataloader/custom_wds.py", line 232, in _return_always_map
    result = f(sample)
             ^^^^^^^^^
  File "/home/rohit.bharadwaj/.conda/envs/LCM/lib/python3.11/site-packages/webdataset/autodecode.py", line 562, in __call__
    return self.decode(sample)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/packages/vid2dataset/video2dataset/dataloader/custom_wds.py", line 144, in decode
    result[k] = self.decode1(k, v)
                ^^^^^^^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/.conda/envs/LCM/lib/python3.11/site-packages/webdataset/autodecode.py", line 518, in decode1
    result = f(key, data)
             ^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/packages/vid2dataset/video2dataset/dataloader/video_decode.py", line 165, in __call__
    frames, start_frame, pad_start = self.get_frames(reader, n_frames, stride, scene_list=scene_list)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/packages/vid2dataset/video2dataset/dataloader/video_decode.py", line 112, in get_frames
    frames = reader.get_batch(np.arange(frame_start, frame_start + n_frames * stride, stride).tolist())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/.conda/envs/LCM/lib/python3.11/site-packages/decord/video_reader.py", line 174, in get_batch
    indices = _nd.array(self._validate_indices(indices))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rohit.bharadwaj/.conda/envs/LCM/lib/python3.11/site-packages/decord/video_reader.py", line 132, in _validate_indices
    raise IndexError('Out of bound indices: {}'.format(indices[indices >= self._num_frame]))
IndexError: Out of bound indices: [ 60  64  68  72  76  80  84  88  92  96 100 104 108 112 116 120 124 128
 132 136 140 144 148 152 156 160 164 168 172 176 180 184 188 192 196 200
 204 208 212 216 220 224 228]
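The IndexError is raised by decord's VideoReader.get_batch, which rejects any frame index at or beyond the number of frames in the video. Independently of how video2dataset handles pad_frames internally, one way to pad without going out of bounds is to clamp the requested indices to the last valid frame (repeating it) and record which positions were real. The `clamped_indices` helper below is hypothetical, not part of video2dataset; the commented-out lines show where it could be used, assuming `reader` is a decord.VideoReader.

```python
import numpy as np


def clamped_indices(num_video_frames: int, n_frames: int, stride: int, frame_start: int = 0):
    """Return (indices, valid) for sampling n_frames every `stride` frames from frame_start.

    Indices past the end of the video are replaced by the last frame index, so the
    clip is padded by repeating its final frame. `valid` is True for real frames.
    """
    if num_video_frames <= 0:
        raise ValueError("video has no frames")
    idx = np.arange(frame_start, frame_start + n_frames * stride, stride)
    valid = idx < num_video_frames
    idx = np.minimum(idx, num_video_frames - 1)
    return idx.tolist(), valid


# Hypothetical use with decord (not video2dataset's own code):
# reader = decord.VideoReader("clip.mp4")
# indices, valid = clamped_indices(len(reader), n_frames=16, stride=4)
# frames = reader.get_batch(indices)  # every index is now < len(reader)

if __name__ == "__main__":
    indices, valid = clamped_indices(num_video_frames=10, n_frames=5, stride=3)
    print(indices)  # [0, 3, 6, 9, 9]
    print(valid.tolist())  # [True, True, True, True, False]
```

Repeating the last frame keeps the clip visually continuous; the returned `valid` mask still lets a downstream loss ignore the repeated positions.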

My code:

from video2dataset.dataloader import get_video_dataset
from webdataset import WebLoader

# WebVid validation split
SHARDS = "<path_to_vid>/WebVid-10M/val-data-wds/{00000..00004}.tar"

decoder_kwargs = {
    "n_frames": 16,  # get 16 frames from each video
    "fps": 8,  # downsample to 8 FPS
    "num_threads": 12,  # use 12 threads to decode the video
    "pad_frames": True,  # pad the video if it has fewer than 16 frames
}
resize_size = crop_size = (448,256)
batch_size = 32

dset = get_video_dataset(
    urls=SHARDS,
    batch_size=batch_size,
    decoder_kwargs=decoder_kwargs,
    resize_size=resize_size,
    crop_size=crop_size,
)

num_workers = 6  # 6 dataloader workers

dl = WebLoader(dset, batch_size=None, num_workers=num_workers)

for sample in dl:
    video_batch = sample["mp4"]
    print(video_batch.shape)  # (batch, frames, height, width, channels), e.g. torch.Size([32, 16, ...])

    # TODO: need to add option for text/metadata preprocessing (tokenization etc.)
    text_batch = sample["txt"]
    print(text_batch[0])
    metadata_batch = sample["json"]
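Inside this loop, video_batch arrives as (batch, frames, height, width, channels), while diffusers' UNet3DConditionModel expects (batch, channels, frames, height, width), so the frame axis has to move (if working in latent space, a VAE would typically be applied per frame first). The sketch below does the layout change and, as a fallback when no padding index is exposed, derives a validity mask by treating all-zero frames as padding. This assumes padded frames are zero-filled (not confirmed here) and would also flag genuinely black frames; it also assumes uint8 pixels in [0, 255]. `to_unet3d_layout` is a hypothetical helper using NumPy; with torch tensors, `permute(0, 4, 1, 2, 3)` performs the same axis move.

```python
import numpy as np


def to_unet3d_layout(video_batch: np.ndarray):
    """Convert a (B, T, H, W, C) uint8 batch to float (B, C, T, H, W) in [-1, 1].

    Also returns a (B, T) boolean mask that is False for frames that are entirely
    zero, under the assumption that padded frames are zero-filled.
    """
    frames = np.asarray(video_batch)
    b, t = frames.shape[:2]
    # A frame counts as real if any pixel in it is non-zero (heuristic, see above).
    valid = frames.reshape(b, t, -1).any(axis=-1)
    x = frames.astype(np.float32) / 127.5 - 1.0  # scale [0, 255] -> [-1, 1]
    x = x.transpose(0, 4, 1, 2, 3)  # (B, T, H, W, C) -> (B, C, T, H, W)
    return x, valid


if __name__ == "__main__":
    batch = np.full((2, 4, 8, 8, 3), 128, dtype=np.uint8)
    batch[1, 2:] = 0  # pretend the second clip was zero-padded from frame 2 onward
    x, valid = to_unet3d_layout(batch)
    print(x.shape)  # (2, 3, 4, 8, 8)
    print(valid.tolist())  # [[True, True, True, True], [True, True, False, False]]
```

The resulting mask has the (batch, frames) shape that a masked loss over real frames would need.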