SysCV / shift-dev

SHIFT Dataset DevKit - CVPR2022
https://www.vis.xyz/shift
MIT License

unable to load continuous video data #42

Closed: monster119120 closed this issue 1 year ago.

monster119120 commented 1 year ago

Here is my code. I want to load data from SHIFT_dataset/continuous instead of SHIFT_dataset/discrete, but even with shift_type="continuous/1x" it always loads the discrete data.

import os
import sys

import torch
from torch.utils.data import DataLoader

# Add the root directory of the project to the path. Uncomment the following two lines
# if shift_dev is not installed as a package.
# root_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
# sys.path.append(root_dir)

from shift_dev import SHIFTDataset
from shift_dev.types import Keys
from shift_dev.utils.backend import ZipBackend, FileBackend

def main():
    """Load the SHIFT dataset and print the tensor shape of the first batch."""

    dataset = SHIFTDataset(
        data_root="../SHIFT_dataset/",
        split="val",
        keys_to_load=[
            Keys.images,
            # Keys.intrinsics,
            Keys.boxes2d,
            Keys.boxes2d_classes,
            # Keys.boxes2d_track_ids,
            # Keys.segmentation_masks,
        ],
        views_to_load=["front"],
        backend=FileBackend(),  # also supports ZipBackend(), HDF5Backend()
        verbose=True,
        framerate="images",
        shift_type="continuous/1x",
    )

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
    )

    # Print the dataset size
    print(f"Total number of samples: {len(dataset)}.")

    # Print the tensor shape of the first batch.
    print('\n')
    for i, batch in enumerate(dataloader):
        print(f"Batch {i}:")
        for k, data in batch["front"].items():
            if isinstance(data, torch.Tensor):
                print(f"{k}: {data.shape}")
            else:
                print(f"{k}: {data}")
        break

    # Print the sample indices within a video.
    # The video indices group frames by video sequence, which is useful for training on videos.
    print('\n')
    video_to_indices = dataset.video_to_indices
    for video, indices in video_to_indices.items():
        print(f"Video name: {video}")
        print(f"Sample indices within a video: {indices}")
        break

if __name__ == "__main__":
    main()
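`video_to_indices` maps each video name to the dataset indices of its frames, which is what makes sequence training possible. As a minimal sketch, assuming the mapping is a plain dict of video name to a list of indices in temporal order (check this against your `shift_dev` version), the hypothetical helper `make_clips` below cuts each video into fixed-length clips. Each clip is a list of dataset indices that could, for example, be wrapped in `torch.utils.data.Subset` or yielded from a custom batch sampler.

```python
from typing import Dict, List


def make_clips(
    video_to_indices: Dict[str, List[int]], clip_len: int, stride: int
) -> List[List[int]]:
    """Hypothetical helper: slide a window of clip_len frames over each video.

    Assumes each list of indices is already in temporal order. Clips never
    cross video boundaries, and trailing frames that do not fill a whole
    clip are dropped.
    """
    if clip_len <= 0 or stride <= 0:
        raise ValueError("clip_len and stride must be positive")
    clips = []
    # Iterate over video names in sorted order so the clip order is deterministic.
    for video in sorted(video_to_indices):
        indices = video_to_indices[video]
        for start in range(0, len(indices) - clip_len + 1, stride):
            clips.append(indices[start:start + clip_len])
    return clips


if __name__ == "__main__":
    # Toy mapping standing in for dataset.video_to_indices.
    toy = {"video_a": [0, 1, 2, 3, 4], "video_b": [5, 6, 7]}
    print(make_clips(toy, clip_len=3, stride=2))  # [[0, 1, 2], [2, 3, 4], [5, 6, 7]]
```

Keeping clips within a single video avoids feeding a temporal model frames from two unrelated sequences; a stride smaller than `clip_len` gives overlapping clips, which is a common way to get more training samples from short videos.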
suniique commented 1 year ago

Hey @monster119120, thanks very much for pointing that out! I confirmed that this is a bug in the example dataloader we provided, and it is now fixed in the recent PR. Please pull the latest code and test it again, and let me know if any other problems remain!
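When re-testing after pulling the fix, it can help to rule out setup problems on the user side (a mistyped `shift_type` or a missing download) before suspecting the dataloader again. The sketch below is a hypothetical stdlib-only helper, not part of `shift_dev`. It assumes the accepted values are `discrete`, `continuous/1x`, `continuous/10x` and `continuous/100x`, and that the first path component (`discrete` or `continuous`) is a top-level folder under `data_root`, matching the SHIFT_dataset/discrete and SHIFT_dataset/continuous folders mentioned in this issue. It cannot detect a bug inside the dataloader itself; it only confirms that the requested data is there to be loaded.

```python
import os
import tempfile

# Assumption: shift types accepted by the devkit; verify against your shift_dev version.
ASSUMED_SHIFT_TYPES = {"discrete", "continuous/1x", "continuous/10x", "continuous/100x"}


def check_shift_type(data_root: str, shift_type: str) -> str:
    """Hypothetical helper: validate shift_type and return the top-level folder it maps to."""
    if shift_type not in ASSUMED_SHIFT_TYPES:
        raise ValueError(
            f"Unknown shift_type {shift_type!r}; expected one of {sorted(ASSUMED_SHIFT_TYPES)}"
        )
    # "continuous/1x" -> "continuous", "discrete" -> "discrete" (assumed layout).
    top_level = shift_type.split("/")[0]
    folder = os.path.join(data_root, top_level)
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"Expected data for {shift_type!r} under {folder}")
    return folder


if __name__ == "__main__":
    # Throwaway directory tree mimicking SHIFT_dataset/{discrete,continuous}.
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "discrete"))
        os.makedirs(os.path.join(root, "continuous"))
        # Prints the temporary root joined with "continuous".
        print(check_shift_type(root, "continuous/1x"))
```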