minar09 / cp-vton-plus

Official implementation for "CP-VTON+: Clothing Shape and Texture Preserving Image-Based Virtual Try-On", CVPRW 2020
https://minar09.github.io/cpvtonplus/
MIT License
356 stars 122 forks

Getting RuntimeError: Given groups=1, weight of size [64, 22, 4, 4], expected input[1, 29, 256, 192] to have 22 channels, but got 29 channels instead error #85

Open lokesh0606 opened 2 years ago

lokesh0606 commented 2 years ago

I'm getting this runtime error:

RuntimeError: Given groups=1, weight of size [64, 22, 4, 4], expected input[1, 29, 256, 192] to have 22 channels, but got 29 channels instead

I generated the parsing images using LIP_JPPNet, as suggested here, for image-parse and image-parse-new. But even after doing that and providing the dataset in the suggested layout, I'm still getting the above error. When testing with the dataset provided by CP-VTON, it worked perfectly, without any errors.
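For reference, cp_dataset.py builds the GMM input as agnostic = torch.cat([shape, im_h, pose_map], 0), so the channel count is 1 (shape) + 3 (head) + one per pose keypoint; the seven extra channels presumably come from the pose map. A minimal check of a single pose file (the file name below is just a hypothetical example):

# Quick check: derive the agnostic channel count from one pose JSON.
# '000001_0_keypoints.json' is an example name; any file under data/test/pose works.
import json
import os.path as osp

with open(osp.join('data', 'test', 'pose', '000001_0_keypoints.json')) as f:
    pose_data = json.load(f)['people'][0]['pose_keypoints_2d']

num_points = len(pose_data) // 3  # each keypoint is an (x, y, confidence) triple
print('agnostic channels:', 1 + 3 + num_points)  # the GMM checkpoint expects 22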

Code in test.py below:

# coding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F

import argparse
import os
import time
from cp_dataset import CPDataset, CPDataLoader
from networks import GMM, UnetGenerator, load_checkpoint

from tensorboardX import SummaryWriter
from visualization import board_add_image, board_add_images, save_images

def get_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="GMM")
    # parser.add_argument("--name", default="TOM")

    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('-b', '--batch-size', type=int, default=4)

    parser.add_argument("--dataroot", default="data")

    # parser.add_argument("--datamode", default="train")
    parser.add_argument("--datamode", default="test")

    parser.add_argument("--stage", default="GMM")
    # parser.add_argument("--stage", default="TOM")

    # parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--data_list", default="test_pairs.txt")
    # parser.add_argument("--data_list", default="test_pairs_same.txt")

    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument("--radius", type=int, default=5)
    parser.add_argument("--grid_size", type=int, default=5)

    parser.add_argument('--tensorboard_dir', type=str,
                        default='tensorboard', help='save tensorboard infos')

    parser.add_argument('--result_dir', type=str,
                        default='result', help='save result infos')

    parser.add_argument('--checkpoint', type=str, default='checkpoints/GMM/gmm_final.pth', help='model checkpoint for test')
    # parser.add_argument('--checkpoint', type=str, default='checkpoints/TOM/tom_final.pth', help='model checkpoint for test')

    parser.add_argument("--display_count", type=int, default=1)
    parser.add_argument("--shuffle", action='store_true',
                        help='shuffle input data')

    opt = parser.parse_args()
    return opt

def test_gmm(opt, test_loader, model, board):
    model.eval()

    base_name = os.path.basename(opt.checkpoint)
    name = opt.name
    save_dir = os.path.join(opt.result_dir, name, opt.datamode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
    if not os.path.exists(warp_cloth_dir):
        os.makedirs(warp_cloth_dir)
    warp_mask_dir = os.path.join(save_dir, 'warp-mask')
    if not os.path.exists(warp_mask_dir):
        os.makedirs(warp_mask_dir)
    result_dir1 = os.path.join(save_dir, 'result_dir')
    if not os.path.exists(result_dir1):
        os.makedirs(result_dir1)
    overlayed_TPS_dir = os.path.join(save_dir, 'overlayed_TPS')
    if not os.path.exists(overlayed_TPS_dir):
        os.makedirs(overlayed_TPS_dir)
    warped_grid_dir = os.path.join(save_dir, 'warped_grid')
    if not os.path.exists(warped_grid_dir):
        os.makedirs(warped_grid_dir)
    for step, inputs in enumerate(test_loader.data_loader):
        iter_start_time = time.time()

        c_names = inputs['c_name']
        im_names = inputs['im_name']
        im = inputs['image']
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']
        agnostic = inputs['agnostic']
        c = inputs['cloth']
        cm = inputs['cloth_mask']
        im_c = inputs['parse_cloth']
        im_g = inputs['grid_image']
        shape_ori = inputs['shape_ori']  # original body shape without blurring

        grid, theta = model(agnostic, cm)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        overlay = 0.7 * warped_cloth + 0.3 * im

        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth+im)*0.5, im]]

        # save_images(warped_cloth, c_names, warp_cloth_dir)
        # save_images(warped_mask*2-1, c_names, warp_mask_dir)
        save_images(warped_cloth, im_names, warp_cloth_dir)
        save_images(warped_mask * 2 - 1, im_names, warp_mask_dir)
        save_images(shape_ori * 0.2 + warped_cloth *
                    0.8, im_names, result_dir1)
        save_images(warped_grid, im_names, warped_grid_dir)
        save_images(overlay, im_names, overlayed_TPS_dir)

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f' % (step+1, t), flush=True)

def test_tom(opt, test_loader, model, board):
    model.eval()

    base_name = os.path.basename(opt.checkpoint)
    # save_dir = os.path.join(opt.result_dir, base_name, opt.datamode)
    save_dir = os.path.join(opt.result_dir, opt.name, opt.datamode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    try_on_dir = os.path.join(save_dir, 'try-on')
    if not os.path.exists(try_on_dir):
        os.makedirs(try_on_dir)
    p_rendered_dir = os.path.join(save_dir, 'p_rendered')
    if not os.path.exists(p_rendered_dir):
        os.makedirs(p_rendered_dir)
    m_composite_dir = os.path.join(save_dir, 'm_composite')
    if not os.path.exists(m_composite_dir):
        os.makedirs(m_composite_dir)
    im_pose_dir = os.path.join(save_dir, 'im_pose')
    if not os.path.exists(im_pose_dir):
        os.makedirs(im_pose_dir)
    shape_dir = os.path.join(save_dir, 'shape')
    if not os.path.exists(shape_dir):
        os.makedirs(shape_dir)
    im_h_dir = os.path.join(save_dir, 'im_h')
    if not os.path.exists(im_h_dir):
        os.makedirs(im_h_dir)  # for test data

    print('Dataset size: %05d!' % (len(test_loader.dataset)), flush=True)
    for step, inputs in enumerate(test_loader.data_loader):
        iter_start_time = time.time()

        im_names = inputs['im_name']
        im = inputs['image']
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic']
        c = inputs['cloth']
        cm = inputs['cloth_mask']

        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = F.tanh(p_rendered)
        m_composite = F.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, 2*cm-1, m_composite],
                   [p_rendered, p_tryon, im]]

        save_images(p_tryon, im_names, try_on_dir)
        save_images(im_h, im_names, im_h_dir)
        save_images(shape, im_names, shape_dir)
        save_images(im_pose, im_names, im_pose_dir)
        save_images(m_composite, im_names, m_composite_dir)
        save_images(p_rendered, im_names, p_rendered_dir)  # For test data

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f' % (step+1, t), flush=True)

def main():
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    test_dataset = CPDataset(opt)
    print("Test Dataset :", test_dataset)

    # create dataloader
    test_loader = CPDataLoader(opt, test_dataset)
    print("Test Loader :", test_loader)
    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & test
    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, test_loader, model, board)
    elif opt.stage == 'TOM':
        # model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON
        model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, test_loader, model, board)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))

if __name__ == "__main__":
    main()

Code in cp_dataset.py:

# coding=utf-8

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image
from PIL import ImageDraw

import os.path as osp
import numpy as np

import pandas as pd

import json

class CPDataset(data.Dataset):
    """Dataset for CP-VTON+."""

    def __init__(self, opt):
        super(CPDataset, self).__init__()
        # base setting
        self.opt = opt
        self.root = opt.dataroot
        self.datamode = opt.datamode  # train or test or self-defined
        self.stage = opt.stage  # GMM or TOM
        self.data_list = opt.data_list
        self.fine_height = opt.fine_height
        self.fine_width = opt.fine_width
        self.radius = opt.radius
        self.data_path = osp.join(opt.dataroot, opt.datamode)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # load data list
        im_names = []
        c_names = []
        with open(osp.join(opt.dataroot, opt.data_list), 'r') as f:
            for line in f.readlines():
                im_name, c_name = line.strip().split()
                im_names.append(im_name)
                c_names.append(c_name)

        self.im_names = im_names
        self.c_names = c_names

    def name(self):
        return "CPDataset"

    def __getitem__(self, index):
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        if self.stage == 'GMM':
            c = Image.open(osp.join(self.data_path, 'cloth', c_name))
            cm = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
        else:
            c = Image.open(osp.join(self.data_path, 'warp-cloth', im_name))    # c_name, if that is used when saved
            cm = Image.open(osp.join(self.data_path, 'warp-mask', im_name)).convert('L')    # c_name, if that is used when saved

        c = self.transform(c)  # [-1,1]
        cm_array = np.array(cm)
        cm_array = (cm_array >= 128).astype(np.float32)
        cm = torch.from_numpy(cm_array)  # [0,1]
        cm.unsqueeze_(0)

        # person image
        im = Image.open(osp.join(self.data_path, 'image', im_name))
        im = self.transform(im)  # [-1,1]

        """
        LIP labels

        [(0, 0, 0),       # 0=Background
         (128, 0, 0),     # 1=Hat
         (255, 0, 0),     # 2=Hair
         (0, 85, 0),      # 3=Glove
         (170, 0, 51),    # 4=SunGlasses
         (255, 85, 0),    # 5=UpperClothes
         (0, 0, 85),      # 6=Dress
         (0, 119, 221),   # 7=Coat
         (85, 85, 0),     # 8=Socks
         (0, 85, 85),     # 9=Pants
         (85, 51, 0),     # 10=Jumpsuits
         (52, 86, 128),   # 11=Scarf
         (0, 128, 0),     # 12=Skirt
         (0, 0, 255),     # 13=Face
         (51, 170, 221),  # 14=LeftArm
         (0, 255, 255),   # 15=RightArm
         (85, 255, 170),  # 16=LeftLeg
         (170, 255, 85),  # 17=RightLeg
         (255, 255, 0),   # 18=LeftShoe
         (255, 170, 0),   # 19=RightShoe
         (170, 170, 50)   # 20=Skin/Neck/Chest (newly added after running dataset_neck_skin_correction.py)
         ]
        """

        # load parsing image
        parse_name = im_name.replace('.jpg', '.png')
        im_parse = Image.open(
            # osp.join(self.data_path, 'image-parse', parse_name)).convert('L')
            osp.join(self.data_path, 'image-parse-new', parse_name)).convert('L')   # updated new segmentation
        parse_array = np.array(im_parse)
        im_mask = Image.open(
            osp.join(self.data_path, 'image-mask', parse_name)).convert('L')
        mask_array = np.array(im_mask)

        # parse_shape = (parse_array > 0).astype(np.float32)  # CP-VTON body shape
        # Get shape from body mask (CP-VTON+)
        parse_shape = (mask_array > 0).astype(np.float32)

        if self.stage == 'GMM':
            parse_head = (parse_array == 1).astype(np.float32) + \
                (parse_array == 4).astype(np.float32) + \
                (parse_array == 13).astype(
                    np.float32)  # CP-VTON+ GMM input (reserved regions)
        else:
            parse_head = (parse_array == 1).astype(np.float32) + \
                (parse_array == 2).astype(np.float32) + \
                (parse_array == 4).astype(np.float32) + \
                (parse_array == 9).astype(np.float32) + \
                (parse_array == 12).astype(np.float32) + \
                (parse_array == 13).astype(np.float32) + \
                (parse_array == 16).astype(np.float32) + \
                (parse_array == 17).astype(
                np.float32)  # CP-VTON+ TOM input (reserved regions)

        parse_cloth = (parse_array == 5).astype(np.float32) + \
            (parse_array == 6).astype(np.float32) + \
            (parse_array == 7).astype(np.float32)    # upper-clothes labels

        # shape downsample
        parse_shape_ori = Image.fromarray((parse_shape*255).astype(np.uint8))
        parse_shape = parse_shape_ori.resize(
            (self.fine_width//16, self.fine_height//16), Image.BILINEAR)
        parse_shape = parse_shape.resize(
            (self.fine_width, self.fine_height), Image.BILINEAR)
        parse_shape_ori = parse_shape_ori.resize(
            (self.fine_width, self.fine_height), Image.BILINEAR)
        shape_ori = self.transform(parse_shape_ori)  # [-1,1]
        shape = self.transform(parse_shape)  # [-1,1]
        phead = torch.from_numpy(parse_head)  # [0,1]
        # phand = torch.from_numpy(parse_hand)  # [0,1]
        pcm = torch.from_numpy(parse_cloth)  # [0,1]

        # upper cloth
        im_c = im * pcm + (1 - pcm)  # [-1,1], fill 1 for other parts
        im_h = im * phead - (1 - phead)  # [-1,1], fill -1 for other parts

        # load pose points
        pose_name = im_name.replace('.jpg', '_keypoints.json')
        with open(osp.join(self.data_path, 'pose', pose_name), 'r') as f:
            pose_label = json.load(f)
            # pose_data = pose_label['people'][0]['pose_keypoints']
            pose_data = pose_label['people'][0]['pose_keypoints_2d']
            pose_data = np.array(pose_data)
            pose_data = pose_data.reshape((-1, 3))

        point_num = pose_data.shape[0]
        pose_map = torch.zeros(point_num, self.fine_height, self.fine_width)
        r = self.radius
        im_pose = Image.new('L', (self.fine_width, self.fine_height))
        pose_draw = ImageDraw.Draw(im_pose)
        for i in range(point_num):
            one_map = Image.new('L', (self.fine_width, self.fine_height))
            draw = ImageDraw.Draw(one_map)
            pointx = pose_data[i, 0]
            pointy = pose_data[i, 1]
            if pointx > 1 and pointy > 1:
                draw.rectangle((pointx-r, pointy-r, pointx +
                                r, pointy+r), 'white', 'white')
                pose_draw.rectangle(
                    (pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white')
            one_map = self.transform(one_map)
            pose_map[i] = one_map[0]

        # just for visualization
        im_pose = self.transform(im_pose)

        # cloth-agnostic representation
        agnostic = torch.cat([shape, im_h, pose_map], 0)
        print("Shape :", shape.shape)
        print("im_h :", im_h.shape)
        print("pose_map :", pose_map.shape)

        if self.stage == 'GMM':
            im_g = Image.open('grid.png')
            im_g = self.transform(im_g)
        else:
            im_g = ''

        pcm.unsqueeze_(0)  # CP-VTON+

        result = {
            'c_name':   c_name,     # for visualization
            'im_name':  im_name,    # for visualization or ground truth
            'cloth':    c,          # for input
            'cloth_mask':     cm,   # for input
            'image':    im,         # for visualization
            'agnostic': agnostic,   # for input
            'parse_cloth': im_c,    # for ground truth
            'shape': shape,         # for visualization
            'head': im_h,           # for visualization
            'pose_image': im_pose,  # for visualization
            'grid_image': im_g,     # for visualization
            'parse_cloth_mask': pcm,     # for CP-VTON+, TOM input
            'shape_ori': shape_ori,     # original body shape without resize
        }

        return result

    def __len__(self):
        return len(self.im_names)

class CPDataLoader(object):
    def __init__(self, opt, dataset):
        super(CPDataLoader, self).__init__()

        if opt.shuffle:
            train_sampler = torch.utils.data.sampler.RandomSampler(dataset)
        else:
            train_sampler = None

        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=opt.batch_size, shuffle=(
                train_sampler is None),
            num_workers=opt.workers, pin_memory=True, sampler=train_sampler)
        self.dataset = dataset
        self.data_iter = self.data_loader.__iter__()

    def next_batch(self):
        try:
            batch = self.data_iter.__next__()
        except StopIteration:
            self.data_iter = self.data_loader.__iter__()
            batch = self.data_iter.__next__()

        return batch

if __name__ == "__main__":
    print("Check the dataset for geometric matching module!")

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", default="data")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--stage", default="GMM")
    parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument("--radius", type=int, default=3)
    parser.add_argument("--shuffle", action='store_true',
                        help='shuffle input data')
    parser.add_argument('-b', '--batch-size', type=int, default=4)
    parser.add_argument('-j', '--workers', type=int, default=1)

    opt = parser.parse_args()
    dataset = CPDataset(opt)
    data_loader = CPDataLoader(opt, dataset)

    print('Size of the dataset: %05d, dataloader: %04d'
          % (len(dataset), len(data_loader.data_loader)))
    first_item = dataset.__getitem__(0)
    first_batch = data_loader.next_batch()

    from IPython import embed
    embed()


vinodbukya6 commented 2 years ago

Hi @lokesh0606, your error is related to the pose JSON files. Check the pose keypoints extraction from OpenPose. Example file:

{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[92.7477,30.0611,0.899314,98.3928,54.0358,0.898113,75.1456,51.932,0.772811,63.8703,88.5553,0.846992,63.8942,121.655,0.893302,120.981,58.2507,0.836377,123.076,92.7458,0.805378,127.271,127.266,0.870634,76.5765,125.175,0.582027,68.8339,173.802,0.742368,70.921,218.178,0.787068,106.166,127.285,0.634873,105.457,178.722,0.754785,110.37,232.965,0.77029,88.5551,25.1222,0.926259,97.7201,25.1486,0.968981,0,0,0,109.681,25.1539,0.957358],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]}

len(pose_keypoints_2d) is 54, i.e. 18 COCO keypoints × (x, y, confidence). Command:

!cd openpose && ./build/examples/openpose/openpose.bin --image_dir /content/openpose/image_resize --model_pose COCO --display 0 --render_pose 0 --write_json /content/openpose/pose2
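The arithmetic behind the error: OpenPose's default BODY_25 model gives 25 keypoints, so the agnostic tensor becomes 1 (shape) + 3 (head) + 25 (pose map) = 29 channels, while the pretrained GMM expects 1 + 3 + 18 = 22 from the COCO model. A minimal sketch to flag any non-COCO pose files before re-running test.py (assuming the data/test/pose layout and *_keypoints.json naming from cp_dataset.py):

# Hedged sketch: flag pose JSONs that do not contain 18 COCO keypoints.
import glob
import json
import os.path as osp

for pose_path in sorted(glob.glob(osp.join('data', 'test', 'pose', '*_keypoints.json'))):
    with open(pose_path) as f:
        keypoints = json.load(f)['people'][0]['pose_keypoints_2d']
    points = len(keypoints) // 3  # each keypoint is an (x, y, confidence) triple
    if points != 18:
        print('%s: %d keypoints (re-run OpenPose with --model_pose COCO)'
              % (pose_path, points))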