ControlNet / MARLIN

[CVPR] MARLIN: Masked Autoencoder for facial video Representation LearnINg
https://openaccess.thecvf.com/content/CVPR2023/html/Cai_MARLIN_Masked_Autoencoder_for_Facial_Video_Representation_LearnINg_CVPR_2023_paper

bugfix:boundary error in _load_video #8

Closed Nikki-Gu closed 1 year ago

Nikki-Gu commented 1 year ago

(1) There is a boundary error in marlin_pytorch.marlin.py line 160. When total_frames = self.clip_frames * sample_rate (in the default setting, total_frames = 32), the program falls through to line 167 and extracts features based on a sliding window, resulting in errors. (2) ffmpeg.probe reports more total frames than torchvision.io.read_video actually returns, leading to errors in the subsequent condition check.
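
To make the boundary case concrete, here is a small sketch; clip_frames = 16 and sample_rate = 2 are assumed defaults, consistent with the total_frames = 32 mentioned above:

    # assumed default settings, consistent with the 32 frames mentioned above
    clip_frames = 16   # frames per extracted clip
    sample_rate = 2    # temporal subsampling factor
    total_frames = 32  # frame count reported by ffmpeg.probe for the input video

    # with the strict comparison, the equality case is not caught ...
    print(total_frames < clip_frames * sample_rate)   # False
    # ... so a 32-frame video falls through to the sliding-window branch even though
    # it only yields a single subsampled clip, and the extraction fails there
    print(total_frames <= clip_frames * sample_rate)  # True -> handled by the clip branch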

ControlNet commented 1 year ago

Hi, thank you for proposing the pull request.

For problem (1), I think changing the condition total_frames < self.clip_frames * sample_rate to total_frames <= self.clip_frames * sample_rate should solve the problem.

For problem (2), the main concern is that if the input video is several GB and very long, for example more than 1 hour, using torchvision.io.read_video would load all of the video's RGB values into memory and cause OOM. Using ffprobe reads the video metadata without loading the frames, avoiding a bottleneck in both time cost and memory size. Please suggest if you have a good way to get the precise video frame count without loading it. How about cv2.VideoCapture(video_name).get(cv2.CAP_PROP_FRAME_COUNT) from OpenCV?
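
For reference, a minimal sketch of both metadata-only queries (example.mp4 is a placeholder path; note that, as discussed here, both counts can be larger than what torchvision.io.read_video actually decodes):

    import cv2
    import ffmpeg

    video_path = "example.mp4"  # placeholder input file

    # ffmpeg-python: probe only reads container metadata, no frames are decoded into memory
    probe = ffmpeg.probe(video_path)
    ffprobe_frames = int(probe["streams"][0]["nb_frames"])

    # OpenCV: also a header-level query, the video is not loaded into memory
    cap = cv2.VideoCapture(video_path)
    opencv_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    print(ffprobe_frames, opencv_frames)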

Nikki-Gu commented 1 year ago

Thank you for replying! I have tried cv2.VideoCapture(video_name).get(cv2.CAP_PROP_FRAME_COUNT) from OpenCV; however, it gives the same result as ffprobe. The problem may lie in torchvision.io.read_video. I found this issue has been discussed in https://github.com/pytorch/vision/issues/2490. However, it is too hard to read for me as a beginner in video processing :( Maybe you can read the above issue page and find a good solution.

From my point of view, considering the time cost and memory size, how about still using ffprobe for the branch selection, while double-checking the number of actually decoded frames inside the total_frames <= self.clip_frames * sample_rate branch:

    def _load_video(self, video_path: str, sample_rate: int, stride: int) -> Generator[Tensor, None, None]:
        probe = ffmpeg.probe(video_path)
        total_frames = int(probe["streams"][0]["nb_frames"])
        if total_frames <= self.clip_frames:
            video = read_video(video_path, channel_first=True) / 255  # (T, C, H, W)
            # pad frames to 16
            v = padding_video(video, self.clip_frames, "same")  # (T, C, H, W)
            assert v.shape[0] == self.clip_frames
            yield v.permute(1, 0, 2, 3).unsqueeze(0).to(self.device)
        elif total_frames <= self.clip_frames * sample_rate:
            video = read_video(video_path, channel_first=True) / 255  # (T, C, H, W)
            # use the first 16 frames, padding if fewer were actually decoded
            if video.shape[0] < self.clip_frames:
                # double check the number of decoded frames, since ffprobe may over-count;
                # see https://github.com/pytorch/vision/issues/2490 for more information
                v = padding_video(video, self.clip_frames, "same")  # (T, C, H, W)
            else:
                v = video[:self.clip_frames]
            yield v.permute(1, 0, 2, 3).unsqueeze(0).to(self.device)
        ......

ControlNet commented 1 year ago

This sounds good. Could you please commit these changes in the PR?

ControlNet commented 1 year ago

The fix is included in library version 0.3.2.
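
For anyone hitting the same error, upgrading the installed package to 0.3.2 or later (via pip, assuming the package is published on PyPI as marlin-pytorch) picks up this fix.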