iPERDance / iPERCore

Liquid Warping GAN with Attention: A Unified Framework for Human Image Synthesis
https://iperdance.github.io/work/impersonator-plus-plus.html
Apache License 2.0

ProcessedVideoDataset has no background #122

Open · orydatadudes opened 3 years ago

orydatadudes commented 3 years ago

ProcessedVideoDataset returns the following sample:

    sample = {
        "images": images,
        "smpls": smpls,
        "masks": masks
    }

so "bg", "offsets", and "links_ids" are not included, but during training/evaluation (lwg_trainer.py, line 643) "bg" is needed:

    with torch.no_grad():
        images = inputs["images"].to(device, non_blocking=True)
        bg = inputs["bg"].to(device, non_blocking=True)
        smpls = inputs["smpls"].to(device, non_blocking=True)
        masks = inputs["masks"].to(device, non_blocking=True)
        offsets = inputs["offsets"].to(device, non_blocking=True)
        links_ids = inputs["links_ids"].to(device, non_blocking=True) if "links_ids" in inputs else ...

What should the else branch do, and where are "bg" and "offsets" supposed to come from?
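For now, the only workaround I can think of is to guard each optional key with a default (a sketch; the None and zero fallbacks are my guesses, not something the repo documents):

    with torch.no_grad():
        images = inputs["images"].to(device, non_blocking=True)
        smpls = inputs["smpls"].to(device, non_blocking=True)
        masks = inputs["masks"].to(device, non_blocking=True)
        # guessed fallbacks for the keys ProcessedVideoDataset does not provide
        bg = inputs["bg"].to(device, non_blocking=True) if "bg" in inputs else None
        offsets = inputs["offsets"].to(device, non_blocking=True) if "offsets" in inputs else 0
        links_ids = inputs["links_ids"].to(device, non_blocking=True) if "links_ids" in inputs else None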

thank you

Chalet37 commented 3 years ago

Have you found any solution?

xfguo-ucas commented 1 year ago

I changed the dataset loading code (iPERCore/data/processed_video_dataset.py) and it works:

def _load_pairs(self, vid_info):
    ns = self._opt.num_source

    length = vid_info["length"]
    ft_ids = vid_info["ft_ids"]

    # sample `ns` source frames, always anchoring the first one to the first front-view frame
    replace = ns >= len(ft_ids)
    src_ids = list(np.random.choice(ft_ids, ns, replace=replace))
    src_ids[0] = ft_ids[0]

    # sample `time_step` target frames and keep them in temporal order
    tsf_ids = list(np.random.choice(length, self._opt.time_step, replace=False))
    tsf_ids.sort()

    # take the source and target ids
    pair_ids = src_ids + tsf_ids
    smpls = vid_info["smpls"][pair_ids]

    images = []
    masks = []
    image_dir = vid_info["img_dir"]
    images_names = vid_info["images"]
    alphas_paths = vid_info["alpha_paths"]
    pseudo_bgs = []
    offsets = vid_info["offsets"]
    bg_dir = vid_info["bg_dir"]

    for t in pair_ids:
        image_path = os.path.join(image_dir, images_names[t])
        image = cv_utils.read_cv2_img(image_path)

        images.append(image)

        mask = cv_utils.read_mask(alphas_paths[t], self._opt.image_size)

        # invert the alpha matte so that foreground is 0 and background is 1
        mask = 1.0 - mask
        masks.append(mask)

    # the preprocessing stage writes one pseudo background per frame,
    # named "<frame stem>_replaced.png", into bg_dir
    bg_img_paths = []
    for s_id in src_ids:
        name = images_names[s_id]
        bg_name = name.split(".")[0] + "_replaced.png"
        bg_path = os.path.join(bg_dir, bg_name)
        bg_img_paths.append(bg_path)

    for bg_path in bg_img_paths:
        bg_img = cv_utils.read_cv2_img(bg_path)
        bg_img = cv_utils.normalize_img(bg_img, image_size=self._opt.image_size, transpose=True)
        pseudo_bgs.append(bg_img)
    pseudo_bgs = np.stack(pseudo_bgs)

    return images, smpls, masks, offsets, pseudo_bgs
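This relies on the preprocessing stage having already written one pseudo background per source frame into bg_dir. Before training, a quick sanity check along these lines can save a crash mid-epoch (my own sketch; check_pseudo_bgs is a hypothetical helper, not part of iPERCore):

    import os

    def check_pseudo_bgs(vid_info):
        """List the pseudo-background files _load_pairs expects but which are missing on disk."""
        bg_dir = vid_info["bg_dir"]
        missing = []
        for name in vid_info["images"]:
            bg_name = name.split(".")[0] + "_replaced.png"  # same naming rule as _load_pairs
            if not os.path.exists(os.path.join(bg_dir, bg_name)):
                missing.append(bg_name)
        return missing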

def __getitem__(self, index):
    """

    Args:
        index (int): the sample index of self._dataset_size.

    Returns:
        sample (dict): the data sample, which contains the following information:
            --images (torch.Tensor): (ns + nt, 3, h, w), where `ns` and `nt` are the numbers of source and target frames;
            --masks (torch.Tensor): (ns + nt, 1, h, w);
            --smpls (torch.Tensor): (ns + nt, 85);
            --offsets: the SMPL vertex offsets of this video;
            --bg (torch.Tensor): (ns, 3, h, w), the pseudo backgrounds of the source frames.

    """

    vid_info = self._vids_info[index % self._num_videos]

    images, smpls, masks, offsets, pseudo_bgs = self._load_pairs(vid_info)

    # pack data
    sample = {
        "images": images,
        "smpls": smpls,
        "masks": masks,
        "offsets": offsets,
        "bg":pseudo_bgs
    }

    sample = self._transform(sample)

    return sample
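With this change the sample carries "bg" and "offsets" as well, so the lwg_trainer.py code above no longer raises a KeyError. A quick smoke test (a sketch; the constructor arguments and the exact tensor shapes after _transform are assumptions on my side):

    from torch.utils.data import DataLoader

    # hypothetical check: `opt` must provide num_source, time_step, image_size, etc.
    dataset = ProcessedVideoDataset(opt, is_for_train=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))
    print(batch["images"].shape)  # expected (1, ns + nt, 3, h, w)
    print(batch["bg"].shape)      # expected (1, ns, 3, h, w)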