Closed — StarCycle closed this issue 2 months ago.
Got it — could you help by opening a PR? This is a real issue and needs to be fixed.
I think it should be easy to fix:
class SFTDataset(Dataset):
    """Video/caption dataset that decodes each clip on demand.

    Only file paths are collected when the dataset is constructed; frames are
    decoded inside ``__getitem__``, so memory use does not grow with the number
    of videos in ``data_dir``.
    """

    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        """
        Args:
            data_dir: root directory searched recursively for ``.mp4`` files.
            video_size: target size forwarded to ``resize_for_rectangle_crop``.
            fps: target sampling rate (frames per second) used for long clips.
            max_num_frames: upper bound on the number of frames per sample.
            skip_frms_num: ignore the first and the last xx frames, avoiding transitions.

        The caption for ``<dir>/<name>.mp4`` is looked up at ``<dir>/<name>.txt``
        with every occurrence of ``"videos"`` in that path replaced by ``"labels"``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frms_num = skip_frms_num
        self.video_paths = []
        self.caption_paths = []

        # Ask decord to hand back torch tensors from get_batch.
        decord.bridge.set_bridge("torch")
        for root, dirnames, filenames in os.walk(data_dir):
            # Sorting dirnames in place steers os.walk, so traversal order (and
            # therefore dataset indexing) does not depend on the filesystem.
            dirnames.sort()
            for filename in sorted(filenames):
                if not filename.endswith(".mp4"):
                    continue
                video_path = os.path.join(root, filename)
                # Swap only the trailing extension; str.replace would also rewrite
                # ".mp4" appearing elsewhere in the file name.
                caption_name = filename[: -len(".mp4")] + ".txt"
                caption_path = os.path.join(root, caption_name).replace("videos", "labels")
                self.video_paths.append(video_path)
                self.caption_paths.append(caption_path)

    @staticmethod
    def _nearest_smaller_4k_plus_1(n):
        """Return the largest value of the form 4k+1 that is <= n (for n >= 1)."""
        # Equivalent to: n % 4 == 0 -> n - 3, otherwise n - (n % 4) + 1.
        return n - (n - 1) % 4

    @classmethod
    def _sample_frame_indices(cls, ori_vlen, actual_fps, fps, max_num_frames, skip_frms_num):
        """Choose which frame positions of a video to decode.

        Args:
            ori_vlen: number of frames in the video.
            actual_fps: the video's average frame rate.
            fps: target sampling rate.
            max_num_frames: upper bound on the clip length.
            skip_frms_num: number of frames to ignore at each end of the video.

        Returns:
            ``(indices, num_frames)``: ``indices`` is a sorted 1-D int array with
            at most ``num_frames`` entries, all inside
            ``[skip_frms_num, ori_vlen - skip_frms_num)``. ``num_frames`` is the
            clip length the caller should pad the decoded frames to.

        Raises:
            ValueError: if no frames remain after skipping both ends.
        """
        start = int(skip_frms_num)
        stop = int(ori_vlen - skip_frms_num)  # exclusive; trailing frames are skipped too
        if stop <= start:
            raise ValueError(
                f"video has {ori_vlen} frames, too few to skip {skip_frms_num} frames at each end"
            )

        if ori_vlen / actual_fps * fps > max_num_frames:
            # Long video: take max_num_frames frames at roughly the target fps,
            # starting right after the skipped head.
            num_frames = max_num_frames
            end = int(start + num_frames / fps * actual_fps)
            # max(..., 1): when the source fps is below the target fps the
            # integer stride would be 0, which np.arange rejects.
            step = max((end - start) // num_frames, 1)
            # Clamp to the usable range; otherwise the fps window can run past
            # the end of the video when start > 0.
            indices = np.arange(start, min(end, stop), step)
        elif ori_vlen > max_num_frames:
            # More frames than allowed, but short in time: spread
            # max_num_frames frames uniformly over the usable range.
            num_frames = max_num_frames
            step = max((stop - start) // num_frames, 1)
            indices = np.arange(start, stop, step)
        else:
            # Short video: keep every usable frame, truncated to 4k+1 frames
            # because the 3D VAE requires the number of frames to be 4k+1.
            num_frames = cls._nearest_smaller_4k_plus_1(stop - start)
            indices = np.arange(start, start + num_frames)

        # Integer strides can over-sample; never return more than num_frames.
        return indices[:num_frames].astype(int), num_frames

    @staticmethod
    def _read_caption(caption_path):
        """Return the first line of the caption file, or "" if it is missing or empty."""
        if not os.path.exists(caption_path):
            return ""
        with open(caption_path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
        return lines[0] if lines else ""

    def __getitem__(self, index):
        """Decode, resize and normalize one clip and load its caption.

        Returns:
            dict with keys ``"mp4"`` (frame tensor in ``[T, C, H, W]`` order,
            scaled by ``(x - 127.5) / 127.5``), ``"txt"`` (caption string,
            "" when absent), ``"num_frames"`` and ``"fps"``.
        """
        video_path = self.video_paths[index]
        caption_path = self.caption_paths[index]

        vr = VideoReader(uri=video_path, height=-1, width=-1)
        indices, num_frames = self._sample_frame_indices(
            len(vr), vr.get_avg_fps(), self.fps, self.max_num_frames, self.skip_frms_num
        )
        # Decode only the sampled frames instead of the whole span between
        # them, so peak memory scales with the clip length, not the video length.
        temp_frms = vr.get_batch(indices)
        if temp_frms is None:
            raise RuntimeError(f"failed to decode frames from {video_path}")
        tensor_frms = temp_frms if isinstance(temp_frms, torch.Tensor) else torch.from_numpy(temp_frms)
        # the len of indices may be less than num_frames, due to round error
        tensor_frms = pad_last_frame(tensor_frms, num_frames)
        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
        tensor_frms = (tensor_frms - 127.5) / 127.5

        item = {
            "mp4": tensor_frms,
            "txt": self._read_caption(caption_path),
            "num_frames": num_frames,
            "fps": self.fps,
        }
        return item

    def __len__(self):
        """Number of videos found under ``data_dir``."""
        return len(self.video_paths)

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        """Factory hook: build the dataset rooted at ``path``; ``args`` is unused."""
        return cls(data_dir=path, **kwargs)
I have implemented this so that each video and its caption are loaded on demand (inside `__getitem__`) rather than cached up front. You can check the implementation here.
In addition, I read the `video_path` and `caption` from a CSV file instead of deriving them from the directory structure under the given path.
Is this also an issue in the diffusers fine-tuning code?
Feature request / 功能建议
It seems that all data is simply cached in CPU memory. If I try to fine-tune the model on a larger dataset, I get an out-of-memory (OOM) error.
https://github.com/THUDM/CogVideo/blob/edc382c74f7f34a8e9a6b403ff03cf9fd8f728aa/sat/data_video.py#L425
Motivation / 动机
Support fine-tuning on larger datasets.
Your contribution / 您的贡献
Would it be possible to decode the videos lazily here instead:
https://github.com/THUDM/CogVideo/blob/edc382c74f7f34a8e9a6b403ff03cf9fd8f728aa/sat/data_video.py#L437
If so, I can open a PR.