THUDM / CogVideo

text and image to video generation: CogVideoX (2024) and CogVideo (ICLR 2023)
Apache License 2.0

You simply cache all data in memory, so it does not support larger datasets #127

Closed: StarCycle closed this issue 2 months ago

StarCycle commented 3 months ago

Feature request

It seems that you cache all of the data in CPU memory. If I try to fine-tune the model on a larger dataset, I get an out-of-memory (OOM) error.

https://github.com/THUDM/CogVideo/blob/edc382c74f7f34a8e9a6b403ff03cf9fd8f728aa/sat/data_video.py#L425
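To see why caching every clip breaks down at scale, here is a rough back-of-the-envelope estimate. The clip shape (49 frames at 480x720, RGB) and the dataset size of 1,000 clips are assumptions chosen for illustration, not values read from the CogVideo code; the float32 figure applies only if the frames are kept after normalization rather than as raw uint8.

```python
def cached_dataset_bytes(num_videos, num_frames, height, width, channels=3, bytes_per_value=1):
    """Approximate RAM needed to hold every decoded clip at once (ignores container overhead)."""
    per_clip = num_frames * height * width * channels * bytes_per_value
    return num_videos * per_clip


# Hypothetical example: 1,000 clips of 49 frames at 480x720.
uint8_total = cached_dataset_bytes(1000, 49, 480, 720)                    # raw uint8 frames
fp32_total = cached_dataset_bytes(1000, 49, 480, 720, bytes_per_value=4)  # if stored as float32
print(f"uint8: {uint8_total / 1e9:.1f} GB, float32: {fp32_total / 1e9:.1f} GB")
# uint8: 50.8 GB, float32: 203.2 GB
```

The cost grows linearly with the number of clips, so any eager cache eventually exceeds CPU memory, no matter how large the machine.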

Motivation

Support fine-tuning on larger datasets.

Your contribution

Would it be possible to decode the videos here instead:

https://github.com/THUDM/CogVideo/blob/edc382c74f7f34a8e9a6b403ff03cf9fd8f728aa/sat/data_video.py#L437

If so, I can open a PR.
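A minimal sketch of the pattern being proposed: keep only the file paths in memory and decode each video when its sample is requested. LazyVideoDataset and load_clip are hypothetical names used for illustration. The class follows the __len__/__getitem__ protocol of PyTorch map-style datasets, but it does not import torch, so it runs standalone.

```python
from typing import Callable, Dict, List


class LazyVideoDataset:
    """Map-style dataset that stores only paths and decodes on access.

    `load_clip` is a hypothetical callable (e.g. a wrapper around a video
    decoder) that turns a path into frames. Nothing is decoded in __init__,
    so memory use no longer grows with the number of videos.
    """

    def __init__(self, video_paths: List[str], load_clip: Callable[[str], object]):
        self.video_paths = list(video_paths)  # cheap: strings only
        self.load_clip = load_clip

    def __len__(self) -> int:
        return len(self.video_paths)

    def __getitem__(self, index: int) -> Dict[str, object]:
        path = self.video_paths[index]
        # Decoding happens here, one sample at a time
        # (in a DataLoader worker process when num_workers > 0).
        return {"path": path, "frames": self.load_clip(path)}


# Usage with a stand-in loader that records which files were actually decoded.
decoded = []


def fake_loader(path):
    decoded.append(path)
    return f"frames of {path}"


ds = LazyVideoDataset(["a.mp4", "b.mp4", "c.mp4"], fake_loader)
print(len(ds), decoded)   # 3 []
print(ds[1]["frames"])    # frames of b.mp4
print(decoded)            # ['b.mp4']
```

The trade-off is that decoding now happens on every access, so throughput depends on decoder speed and on how many loader workers run in parallel, rather than on available RAM.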

zRzRzRzRzRzRzR commented 2 months ago

Got it. Could you help by opening a PR? This is a real issue and needs to be fixed.

trouble-maker007 commented 2 months ago

I think this is easy to fix:

import os

import decord
import numpy as np
import torch
from decord import VideoReader
from torch.utils.data import Dataset

# pad_last_frame and resize_for_rectangle_crop are assumed to be the helpers
# already defined in sat/data_video.py; import or copy them from there.


def nearest_smaller_4k_plus_1(n):
    # Largest value of the form 4k+1 that is <= n; the 3D VAE requires 4k+1 frames
    remainder = n % 4
    return n - 3 if remainder == 0 else n - remainder + 1


class SFTDataset(Dataset):
    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        """
        skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
        """
        super().__init__()

        self.data_dir = data_dir
        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frms_num = skip_frms_num

        # Only paths are kept in memory; frames are decoded lazily in __getitem__
        self.video_paths = []
        self.caption_paths = []

        decord.bridge.set_bridge("torch")
        for root, _, filenames in os.walk(data_dir):
            for filename in filenames:
                if filename.endswith(".mp4"):
                    video_path = os.path.join(root, filename)
                    # Captions are expected in a parallel "labels" tree (note: replaces every "videos" in the path)
                    caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
                    self.video_paths.append(video_path)
                    self.caption_paths.append(caption_path)

    def __getitem__(self, index):
        video_path = self.video_paths[index]
        caption_path = self.caption_paths[index]

        vr = VideoReader(uri=video_path, height=-1, width=-1)
        actual_fps = vr.get_avg_fps()
        ori_vlen = len(vr)
        start = int(self.skip_frms_num)

        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
            # Long video: sample max_num_frames at the target fps
            num_frames = self.max_num_frames
            end = int(start + num_frames / self.fps * actual_fps)
            step = max((end - start) // num_frames, 1)  # a step of 0 would make np.arange fail
            indices = np.arange(start, end, step)
            indices = indices[indices < ori_vlen]  # end can run past the last frame
        elif ori_vlen > self.max_num_frames:
            # Short duration but more raw frames than needed: subsample uniformly
            num_frames = self.max_num_frames
            end = int(ori_vlen - self.skip_frms_num)
            step = max((end - start) // num_frames, 1)
            indices = np.arange(start, end, step)
        else:
            # Few frames: keep a contiguous window of 4k+1 frames
            end = int(ori_vlen - self.skip_frms_num)
            num_frames = nearest_smaller_4k_plus_1(end - start)
            indices = np.arange(start, start + num_frames)

        # Decode only the selected frames instead of the whole [start, end) range
        tensor_frms = vr.get_batch(indices)
        if not isinstance(tensor_frms, torch.Tensor):  # decord NDArray if the torch bridge is not active
            tensor_frms = torch.from_numpy(tensor_frms.asnumpy())

        tensor_frms = pad_last_frame(tensor_frms, num_frames)  # len(indices) may differ from num_frames due to rounding
        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
        tensor_frms = (tensor_frms - 127.5) / 127.5

        # caption: first line of the .txt file, or "" if it is missing or empty
        caption = ""
        if os.path.exists(caption_path):
            with open(caption_path, "r") as f:
                lines = f.read().splitlines()
            caption = lines[0] if lines else ""

        return {
            "mp4": tensor_frms,
            "txt": caption,
            "num_frames": num_frames,
            "fps": self.fps,
        }

    def __len__(self):
        return len(self.video_paths)

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(data_dir=path, **kwargs)
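
The frame-selection branches in __getitem__ are the easiest part to get wrong, so here is a standalone, pure-Python sketch of that logic: a long video is sampled at the target fps, a video with many raw frames is subsampled uniformly, and a short video is trimmed to a contiguous 4k+1 window. sample_frame_indices is a hypothetical helper name; it uses range instead of np.arange so it runs without NumPy or decord, and it leaves out the padding and resizing steps.

```python
def nearest_smaller_4k_plus_1(n):
    """Largest value of the form 4k+1 that is <= n (for n >= 1)."""
    remainder = n % 4
    return n - 3 if remainder == 0 else n - remainder + 1


def sample_frame_indices(ori_vlen, actual_fps, fps, max_num_frames, skip_frms_num=3):
    """Return (indices, num_frames) following the three branches of SFTDataset.__getitem__."""
    start = skip_frms_num
    if ori_vlen / actual_fps * fps > max_num_frames:
        # Long video: cover max_num_frames frames at the target fps.
        num_frames = max_num_frames
        end = int(start + num_frames / fps * actual_fps)
        step = max((end - start) // num_frames, 1)  # guard against a zero step
        indices = [i for i in range(start, end, step) if i < ori_vlen]  # drop frames past the end
    elif ori_vlen > max_num_frames:
        # More raw frames than needed: subsample uniformly between the skipped margins.
        num_frames = max_num_frames
        end = ori_vlen - skip_frms_num
        step = max((end - start) // num_frames, 1)
        indices = list(range(start, end, step))
    else:
        # Short video: contiguous window whose length is 4k+1.
        end = ori_vlen - skip_frms_num
        num_frames = nearest_smaller_4k_plus_1(end - start)
        indices = list(range(start, start + num_frames))
    return indices, num_frames


idx, n = sample_frame_indices(300, actual_fps=30, fps=8, max_num_frames=49)
print(n, len(idx), idx[:4])   # 49 61 [3, 6, 9, 12]
idx, n = sample_frame_indices(30, actual_fps=30, fps=8, max_num_frames=49)
print(n, idx[0], idx[-1])     # 21 3 23
```

Note that the number of sampled indices can be larger or smaller than num_frames because of the integer step, which is why the result is padded or trimmed afterwards. The two guards (a step of at least 1, and dropping indices at or beyond ori_vlen) cover a source video whose native fps is below the target fps and a sampling window that runs past the last frame.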

bertjiazheng commented 2 months ago

I have implemented this so that videos and captions are loaded on the fly. You can check the implementation here.

In addition, I read 'video_path' and 'caption' from a CSV file instead of deriving the video and caption paths from the data directory.
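A sketch of a CSV-based index like the one described above, using only Python's standard csv module. The file layout (a header row with video_path and caption columns) is an assumption based on the names mentioned in the comment; the linked implementation may differ. read_index and the file name train.csv are hypothetical.

```python
import csv
import io


def read_index(csv_file):
    """Read (video_path, caption) pairs from a CSV file object with a header row.

    The column names "video_path" and "caption" are assumed; adjust them
    to match your own file.
    """
    reader = csv.DictReader(csv_file)
    return [(row["video_path"], row["caption"]) for row in reader]


# Usage with an in-memory CSV; in practice pass open("train.csv", newline="").
sample = io.StringIO(
    "video_path,caption\n"
    "videos/0001.mp4,A cat jumps onto a table.\n"
    '"videos/0002.mp4","A car drives past, then stops."\n'
)
pairs = read_index(sample)
print(len(pairs), pairs[1][1])   # 2 A car drives past, then stops.
```

Keeping captions in the index means __getitem__ only has to decode the video, and CSV quoting lets captions contain commas without breaking the columns.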

Yuancheng-Xu commented 1 month ago

Is this also an issue in the diffusers-based fine-tuning code?