rese1f / MovieChat

[CVPR 2024] MovieChat: From Dense Token to Sparse Memory for Long Video Understanding
https://rese1f.github.io/MovieChat/
BSD 3-Clause "New" or "Revised" License

Processing flow of MovieChat-1K_train #54

LZHgrla opened this issue 6 months ago

LZHgrla commented 6 months ago

Hi! Could you provide the processing script or procedure for the MovieChat-1K_train dataset? We plan to fine-tune our model on this dataset and need to make sure our processing matches the procedure used in the pre-training phase.

Espere-1119-Song commented 6 months ago

For each video in the MovieChat-1K_train dataset, we uniformly sample 8192 frames, resize them to an image_size of 224, encode them with eva_clip_g, and store the features in HDF5. Our feature extraction code is as follows:

import os
import cv2
import torchvision.transforms as transforms
import torch
import h5py
from MovieChat.models.eva_vit import create_eva_vit_g

device = "cuda:0"

input_folder = 'our_train_data'   # directory of raw .mp4 videos
output_folder = 'feature_hdf5'    # one sub-directory of .h5 files per video

if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# Videos that already have an output sub-directory are skipped below.
subfolders = [f.name for f in os.scandir(output_folder) if f.is_dir()]
mp4_files = [f for f in os.listdir(input_folder) if f.endswith('.mp4')]

frames_to_read = 8192  # number of uniformly sampled frames per video

# Resize each frame to 224x224 and convert it to a tensor in [0, 1].
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def init_vision_encoder(
    img_size=224, drop_path_rate=0, use_grad_checkpoint=False, precision="fp16"
):
    """Build a frozen EVA ViT-g (eva_clip_g) visual encoder."""
    visual_encoder = create_eva_vit_g(
        img_size, drop_path_rate, use_grad_checkpoint, precision
    ).float()  # cast back to fp32 for feature extraction
    visual_encoder.eval()

    return visual_encoder

image_encoder = init_vision_encoder().to(device)

def save_features(current_batch, mp4_file, piece_count):
    """Encode one batch of frames and write the features to <video>/<piece_count>.h5."""
    batch = torch.cat(current_batch, dim=0).to(device)
    with torch.no_grad():
        features = image_encoder(batch).cpu()
    output_dir = os.path.join(output_folder, os.path.splitext(mp4_file)[0])
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_path = os.path.join(output_dir, str(piece_count) + '.h5')
    with h5py.File(output_path, "w") as hdf5_file:
        hdf5_file.create_dataset(f"frames_{piece_count}", data=features)
    print(output_path)

for mp4_file in mp4_files:
    if os.path.splitext(mp4_file)[0] in subfolders:
        continue  # features for this video already exist
    try:
        video_path = os.path.join(input_folder, mp4_file)
        cap = cv2.VideoCapture(video_path)

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep every frame_interval-th frame; the max() guards against a
        # zero interval (and a division-by-zero below) for videos with
        # fewer than frames_to_read frames.
        frame_interval = max(1, total_frames // frames_to_read)

        frame_count = 0
        batch_size = 64
        piece_count = 0
        current_batch = []
        while frame_count < total_frames:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                # NOTE: cv2 decodes frames in BGR order; no BGR->RGB
                # conversion is applied before the transform here.
                frame_tensor = transform(frame).unsqueeze(0).to(device)
                current_batch.append(frame_tensor)

                # Encode and dump features once a full batch is collected.
                if len(current_batch) == batch_size:
                    save_features(current_batch, mp4_file, piece_count)
                    piece_count += 1
                    current_batch = []

            frame_count += 1

        # Flush the final, possibly partial batch.
        if len(current_batch) > 0:
            save_features(current_batch, mp4_file, piece_count)

        cap.release()

    except Exception as e:
        print(e)
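With 8192 sampled frames and a batch size of 64, this yields roughly 128 pieces per video, stored as feature_hdf5/<video_name>/0.h5, 1.h5, ..., each holding a single dataset named frames_<piece>.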

However, we didn't use these extracted features when running MovieChat ourselves. I think the main differences are the frame reading in inference.py and the frame encoding in moviechat.py.
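For reference, here is a minimal sketch of how the stored features could be loaded back. The per-video directory layout and the frames_<piece> dataset names follow the script above, but the load_video_features helper and the final concatenation are only an assumption about how you might consume the files, not code from our repo:

import os
import h5py
import torch

def load_video_features(feature_dir):
    # Collect the .h5 pieces for one video and sort them numerically,
    # since they were written as 0.h5, 1.h5, ...
    pieces = sorted(
        (f for f in os.listdir(feature_dir) if f.endswith('.h5')),
        key=lambda f: int(os.path.splitext(f)[0]),
    )
    features = []
    for piece in pieces:
        piece_id = os.path.splitext(piece)[0]
        with h5py.File(os.path.join(feature_dir, piece), "r") as hdf5_file:
            features.append(torch.from_numpy(hdf5_file[f"frames_{piece_id}"][:]))
    # Concatenate the per-batch features back into one tensor per video.
    return torch.cat(features, dim=0)

features = load_video_features('feature_hdf5/video_name')  # 'video_name' is a placeholder
print(features.shape)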

Hope this can be helpful to you! :)

LZHgrla commented 6 months ago

@Espere-1119-Song Awesome! Thanks very much!

HIT-cwh commented 6 months ago

Hi @Espere-1119-Song! I'm not sure I'm understanding this correctly, but in the provided code snippet the video frames obtained through cv2.VideoCapture are in BGR format, whereas the images passed into transforms.ToPILImage() are interpreted as RGB, which leads to a potential inconsistency.
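If so, a one-line fix would be converting the channel order before the transform. A minimal sketch of the change, reusing the cap and transform from the script above:

ret, frame = cap.read()  # OpenCV decodes frames as BGR
if ret:
    # Convert to RGB so the array matches PIL's channel convention
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_tensor = transform(frame)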

Espere-1119-Song commented 6 months ago

Thank you for pointing out the issue! We apologize for any inconvenience caused. We are currently uploading the raw videos to Huggingface, and we expect to complete this by the weekend.

HIT-cwh commented 6 months ago

Thanks

Espere-1119-Song commented 6 months ago

We have uploaded the raw videos of the training set :)