omerbt / TokenFlow

Official PyTorch implementation of "TokenFlow: Consistent Diffusion Features for Consistent Video Editing", presenting "TokenFlow" (ICLR 2024)
https://diffusion-tokenflow.github.io
MIT License
1.52k stars 134 forks source link

Which code implements 'NN field compute & warp'? #43

Open justinday123 opened 3 months ago

justinday123 commented 3 months ago

Hi, thank you for your nice work. Is the following code the part that implements 'NN field compute & warp'?

def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
    """Estimate a normalized MiDaS depth map for every frame in ``self.paths``.

    Each image is read with OpenCV and converted from BGR to RGB. It is then
    passed through the requested MiDaS model, resized (bicubic) to 1/8 of the
    image resolution, and min-max normalized per frame to the range [-1, 1].

    Args:
        model_type: MiDaS model name passed to ``torch.hub.load``. "DPT_Large"
            and "DPT_Hybrid" use the DPT input transform; any other value uses
            the small transform.
        device: Device on which the MiDaS model and its inputs run.

    Returns:
        The per-frame depth maps concatenated along dim 0, moved to
        ``self.device`` and cast to ``torch.float16``. Presumably shaped
        (num_frames, 1, H // 8, W // 8) when the MiDaS transform yields a
        batch of one image -- TODO confirm.

    Raises:
        FileNotFoundError: If a frame in ``self.paths`` cannot be read.
    """
    midas = torch.hub.load("intel-isl/MiDaS", model_type)
    midas.to(device)
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    if model_type in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    depth_maps = []
    # Inference only: without no_grad, each frame's autograd graph would stay
    # alive through the stored depth maps, and memory would grow per frame.
    with torch.no_grad():
        for path in self.paths:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # 8 is presumably the latent-space downsampling factor -- confirm.
            latent_h = img.shape[0] // 8
            latent_w = img.shape[1] // 8

            prediction = midas(transform(img).to(device))
            depth_map = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(latent_h, latent_w),
                mode="bicubic",
                align_corners=False,
            )
            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
            # A constant depth map would otherwise divide by zero (NaN output).
            depth_range = (depth_max - depth_min).clamp_min(1e-8)
            depth_map = 2.0 * (depth_map - depth_min) / depth_range - 1.0
            depth_maps.append(depth_map)

    return torch.cat(depth_maps).to(self.device).to(torch.float16)