ZiqiaoPeng / SyncTalk

[CVPR 2024] This is the official source for our paper "SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis"
https://ziqiaopeng.github.io/synctalk/

This code uses too much GPU memory; how can it be optimized? VRAM usage instantly spikes to 20 GB+ #165

Open einsqing opened 4 months ago

einsqing commented 4 months ago
    face_img = cv2.imread(os.path.join(ori_imgs_dir, '{:d}.jpg'.format(ref_id)))
    face_img_mask = cv2.imread(os.path.join(mask_dir, '{:d}.png'.format(ref_id)))

    rigid_mask = face_img_mask[..., 0] > 250
    rigid_num = np.sum(rigid_mask)
    flow_frame_num = 2500
    flow_frame_num = min(flow_frame_num, valid_img_num)
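    # (flow_frame_num, 2, rigid_num) float32 buffer on the host
    rigid_flow = np.zeros((flow_frame_num, 2, rigid_num), np.float32)
    for i in range(flow_frame_num):
        flow = np.load(os.path.join(flow_dir, '{:d}_{:d}.npy'.format(ref_id, valid_img_ids[i])))
        rigid_flow[i] = flow[:, rigid_mask]
    # (rigid_num, 2, flow_frame_num): the transpose is a strided view, and the
    # whole array is then uploaded to the GPU at full size
    rigid_flow = rigid_flow.transpose((2, 1, 0))
    rigid_flow = torch.as_tensor(rigid_flow).cuda()
    # Temporal Laplacian (second difference) of each pixel's x/y flow
    lap_kernel = torch.Tensor(
        (-0.5, 1.0, -0.5)).unsqueeze(0).unsqueeze(0).float().cuda()
    # reshape() likely cannot return a view of the transposed tensor, so it may
    # make another full-size GPU copy; the conv1d output is about as large again
    flow_lap = F.conv1d(
        rigid_flow.reshape(-1, 1, rigid_flow.shape[-1]), lap_kernel)
    flow_lap = flow_lap.view(rigid_flow.shape[0], 2, -1)
    # Magnitude over the x/y components: (rigid_num, flow_frame_num - 2)
    flow_lap = torch.norm(flow_lap, dim=1)
    # Keep frames whose mean jitter over all pixels is below 3x the global mean
    valid_frame = torch.mean(flow_lap, dim=0) < (torch.mean(flow_lap) * 3)
    flow_lap = flow_lap[:, valid_frame]
    # Per-pixel mean jitter over the kept frames
    rigid_flow_mean = torch.mean(flow_lap, dim=1)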
    rigid_flow_show = (rigid_flow_mean - torch.min(rigid_flow_mean)) / \
                      (torch.max(rigid_flow_mean) - torch.min(rigid_flow_mean)) * 255
    rigid_flow_show = rigid_flow_show.byte().cpu().numpy()
    rigid_flow_img = np.zeros((h, w, 1), dtype=np.uint8)
    rigid_flow_img[...] = 255
    rigid_flow_img[rigid_mask, 0] = rigid_flow_show
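One possible way to avoid the spike, sketched under assumptions: `rigid_flow_mean` only needs the per-pixel Laplacian magnitude, so it can be computed in chunks of pixels instead of on the full `(rigid_num, 2, flow_frame_num)` tensor at once. This sketch does the reduction in NumPy on the CPU, which removes this step's VRAM use entirely; `rigid_flow_lap_mean` and `chunk_size` are hypothetical names, not part of SyncTalk. It takes the host buffer *before* the `transpose`, keeps one `(rigid_num, frames - 2)` float32 array on the host (about half the size of `rigid_flow`), and its temporaries scale with `chunk_size`.

```python
import numpy as np


def rigid_flow_lap_mean(rigid_flow: np.ndarray, chunk_size: int = 8192) -> np.ndarray:
    """Per-pixel mean temporal-Laplacian magnitude of the rigid-region flow.

    rigid_flow: float32 array of shape (num_frames, 2, rigid_num), i.e. the
    host buffer *before* transpose((2, 1, 0)) in the snippet above.
    Returns an array of shape (rigid_num,).
    """
    num_frames, _, rigid_num = rigid_flow.shape
    # (rigid_num, num_frames - 2): roughly half the size of rigid_flow
    flow_lap = np.empty((rigid_num, num_frames - 2), dtype=np.float32)
    for start in range(0, rigid_num, chunk_size):
        stop = min(start + chunk_size, rigid_num)
        chunk = rigid_flow[:, :, start:stop]  # (T, 2, n) view, no copy
        # Second difference along time; matches conv1d with kernel (-0.5, 1.0, -0.5)
        lap = -0.5 * chunk[:-2] + chunk[1:-1] - 0.5 * chunk[2:]  # (T - 2, 2, n)
        # L2 norm over the x/y components, stored transposed as (n, T - 2)
        flow_lap[start:stop] = np.sqrt(np.square(lap).sum(axis=1)).T
    # Same frame filter as the original: per-frame mean over pixels < 3x global mean
    valid_frame = flow_lap.mean(axis=0) < flow_lap.mean() * 3
    # Mean over the kept frames without materializing flow_lap[:, valid_frame]
    return flow_lap @ valid_frame.astype(np.float32) / valid_frame.sum()
```

At the call site this would replace everything from the `transpose` through `rigid_flow_mean = ...`, e.g. `rigid_flow_mean = rigid_flow_lap_mean(rigid_flow)`, with the min-max scaling to 0..255 then done in NumPy instead of torch. If the GPU is preferred, the same loop could upload each chunk with `torch.from_numpy(np.ascontiguousarray(chunk)).cuda()`, so VRAM stays bounded by `chunk_size` rather than by the clip length.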
zhjygit commented 4 months ago

Does it really use that much VRAM? I only have two 8 GB GPUs.

bupojie888 commented 4 months ago

> Does my machine really use this much VRAM? It filled up both of my 8 GB GPUs.

I'm on an RTX 2060 Super with 8 GB and have never run out of VRAM.

einsqing commented 4 months ago


Check the VRAM usage while this code is running. Also, what is the bitrate of your 512×512 video?
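A rough back-of-the-envelope estimate helps explain why the spike depends on the clip, and why an 8 GB card may never hit it. The sketch below assumes float32 tensors and that `reshape` makes a full copy of the transposed tensor, and it ignores cuDNN workspace, the CUDA context, and caching-allocator overhead; `estimate_peak_gpu_bytes` is a hypothetical helper. Here `rigid_num` is the number of mask pixels above 250 and `num_frames` is `flow_frame_num = min(2500, valid_img_num)`.

```python
def estimate_peak_gpu_bytes(rigid_num: int, num_frames: int, itemsize: int = 4) -> int:
    """Rough peak of simultaneously live GPU tensors in the snippet (bytes)."""
    full = rigid_num * 2 * num_frames * itemsize           # rigid_flow after .cuda()
    conv_out = rigid_num * 2 * (num_frames - 2) * itemsize  # F.conv1d output
    norm_out = rigid_num * (num_frames - 2) * itemsize      # torch.norm over x/y
    stages = [
        2 * full + conv_out,         # conv1d: rigid_flow, its reshape copy, output
        full + conv_out + norm_out,  # torch.norm: conv output still referenced
        full + 2 * norm_out,         # boolean mask: upper bound, all frames kept
    ]
    return max(stages)


if __name__ == "__main__":
    # Hypothetical 200,000 rigid pixels: a long clip vs. a shorter one
    for frames in (2500, 1000):
        print(frames, estimate_peak_gpu_bytes(200_000, frames) / 1e9, "GB")
```

With these hypothetical numbers, 2,500 frames gives about 12.0 GB and 1,000 frames about 4.8 GB, so shorter clips or smaller rigid masks stay far below the spike. The real peak around this block can be measured with `torch.cuda.reset_peak_memory_stats()` before it and `torch.cuda.max_memory_allocated()` after it.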