tumurzakov / AnimateDiff

AnimationDiff with train
Apache License 2.0
115 stars 26 forks source link

How to train on multiple video dataset? #3

Closed howardgriffin closed 1 year ago

howardgriffin commented 1 year ago

Hi, the demo shown in your notebook only handles one video. How can I train on a dataset that contains multiple videos?

tumurzakov commented 1 year ago

To train on multiple videos, the TuneAVideoDataset class should be modified to take an array of prompts and an array of video paths.

Something like


import decord
decord.bridge.set_bridge('torch')

from torch.utils.data import Dataset
from einops import rearrange

class MultiTuneAVideoDataset(Dataset):
    """Dataset that pairs each video file with its own text prompt.

    Item ``i`` is decoded from ``video_path[i]``; ``prompt[i]`` is the prompt
    that belongs to that video.  The dataset does not tokenize anything
    itself: ``prompt_ids`` starts as ``None`` and is presumably assigned by
    the training script after construction (TODO confirm against train.py).

    ``prompt_ids`` may hold either

    * one row of token ids per video (a tensor/array with two or more
      dimensions, or a list/tuple of rows) -- item ``i`` then receives row
      ``i``; or
    * a single flat row of token ids -- every item receives that same row.

    Args:
        video_path: Paths of the videos, one per item.  A single ``str`` is
            accepted and treated as a one-element list.
        prompt: Prompts, one per video and in the same order.  A single
            ``str`` is accepted and treated as a one-element list.
        width: Frame width requested from the video reader.
        height: Frame height requested from the video reader.
        n_sample_frames: Maximum number of frames taken from each video.
        sample_start_idx: Index of the first frame to take.
        sample_frame_rate: Step between consecutive sampled frames.

    Raises:
        ValueError: If ``video_path`` and ``prompt`` differ in length, or if
            ``n_sample_frames`` or ``sample_frame_rate`` is smaller than 1.
    """

    def __init__(
            self,
            video_path: list[str],
            prompt: list[str],
            width: int = 512,
            height: int = 512,
            n_sample_frames: int = 8,
            sample_start_idx: int = 0,
            sample_frame_rate: int = 1,
    ):
        # A bare string would otherwise be measured and indexed character by
        # character, so normalise the single-video form to a list.
        video_path = [video_path] if isinstance(video_path, str) else list(video_path)
        prompt = [prompt] if isinstance(prompt, str) else list(prompt)

        if len(video_path) != len(prompt):
            raise ValueError(
                "video_path and prompt must have the same length, got "
                f"{len(video_path)} video(s) and {len(prompt)} prompt(s)"
            )
        if n_sample_frames < 1:
            raise ValueError(f"n_sample_frames must be >= 1, got {n_sample_frames}")
        if sample_frame_rate < 1:
            raise ValueError(f"sample_frame_rate must be >= 1, got {sample_frame_rate}")

        self.video_path = video_path
        self.prompt = prompt
        # Filled in from outside the class; see the class docstring.
        self.prompt_ids = None

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_path)

    @staticmethod
    def _is_batched(ids) -> bool:
        """Tell whether ``ids`` holds one row of token ids per video."""
        # Tensors / arrays: two or more dimensions means "rows of ids".
        if getattr(ids, "ndim", 0) >= 2:
            return True
        # Plain containers: look at the first element to see if it is a row.
        if isinstance(ids, (list, tuple)) and ids:
            first = ids[0]
            return isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) >= 1
        return False

    def _prompt_ids_for(self, index):
        """Return the token ids that belong to item ``index``."""
        ids = self.prompt_ids
        if ids is None:
            return None
        if self._is_batched(ids):
            return ids[index]
        # A single flat row is shared by every video.
        return ids

    def _sample_indices(self, num_frames):
        """Return the frame indices to read from a video of ``num_frames`` frames.

        Starts at ``sample_start_idx``, steps by ``sample_frame_rate`` and
        keeps at most ``n_sample_frames`` indices.  A video that is too short
        yields fewer indices.
        """
        # Slicing the range itself avoids building an index for every frame.
        frames = range(self.sample_start_idx, num_frames, self.sample_frame_rate)
        return list(frames[:self.n_sample_frames])

    def __getitem__(self, index):
        """Load one video clip and return it with its prompt ids.

        Returns:
            A dict with ``"pixel_values"`` (the sampled frames, rearranged
            from ``f h w c`` to ``f c h w`` and mapped through
            ``x / 127.5 - 1.0``) and ``"prompt_ids"`` (the ids for this item,
            or ``None`` if none were assigned).
        """
        # load and sample video frames
        vr = decord.VideoReader(self.video_path[index], width=self.width, height=self.height)
        video = vr.get_batch(self._sample_indices(len(vr)))
        video = rearrange(video, "f h w c -> f c h w")

        # Presumably frames are 0..255, which this maps to [-1, 1] -- TODO confirm.
        return {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": self._prompt_ids_for(index),
        }

Then use this class in place of TuneAVideoDataset in train.py.