boreshkinai / hybrik-transformer


Demo_image.py for hrnet pipeline? #3

Open Dipankar1997161 opened 6 months ago

Dipankar1997161 commented 6 months ago

Hello @boreshkinai,

I saw your demo_image.py code, and it uses a ResNet backbone with the corresponding config file. I was wondering how we can do the same for the HRNet backbone.

I tried to use the pretrained models with the config file "hybrik_transformer_hrnet_smpl24.yaml", but I encountered an error.

(screenshots of the errors attached: error, error_2)
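In case it helps to narrow things down, a quick sanity check along these lines should show whether the HRNet config and the pretrained checkpoint actually describe the same architecture (the paths and the update_config / builder.build_sppe calls here are simply the ones I use in my script below, so please correct me if those are not the intended entry points):

import torch
from hybrik.models import builder
from hybrik.utils.config import update_config

# Build the model from the HRNet config and load the checkpoint on CPU.
cfg = update_config('configs/hybrik_transformer_hrnet_smpl24.yaml')
model = builder.build_sppe(cfg.MODEL)
ckpt = torch.load('pretrained_models/average_model_hrnet_w_pw3d_181_199.pth', map_location='cpu')
# Some checkpoints wrap the weights in a 'state_dict' entry; otherwise fall back to the raw dict.
state_dict = ckpt.get('state_dict', ckpt)

# Compare key sets: large differences suggest the config and checkpoint
# do not describe the same backbone.
model_keys = set(model.state_dict().keys())
ckpt_keys = set(state_dict.keys())
print('keys missing from checkpoint:', len(model_keys - ckpt_keys))
print('keys unexpected in checkpoint:', len(ckpt_keys - model_keys))
print('examples of unexpected keys:', sorted(ckpt_keys - model_keys)[:5])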

Here is my demo_image.py

"""Image demo script."""
import argparse
import os

import cv2
import numpy as np
import torch
from easydict import EasyDict as edict
from hybrik.models import builder
from hybrik.utils.config import update_config
from hybrik.utils.presets import SimpleTransform3DSMPL
from hybrik.utils.render import SMPLRenderer
from hybrik.utils.vis import get_one_box, vis_smpl_3d
from torchvision import transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from tqdm import tqdm

det_transform = T.Compose([T.ToTensor()])

def xyxy2xywh(bbox):
    x1, y1, x2, y2 = bbox

    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return [cx, cy, w, h]

parser = argparse.ArgumentParser(description='HybrIK Demo')
CKPT = 'pretrained_models/average_model_hrnet_w_pw3d_181_199.pth'
#CKPT = "/home/ndip/hybrik_transformer_hrnet_smpl24.yaml-NTOP/model_3.pth"

parser.add_argument('--gpu',
                    help='gpu',
                    default=0,
                    type=int)
parser.add_argument('--checkpoint',
                    help='checkpoint file name',
                    required=True,
                    type=str)
parser.add_argument('--img-dir',
                    help='image folder',
                    default='',
                    type=str)
parser.add_argument('--out-dir',
                    help='output folder',
                    default='',
                    type=str)
opt = parser.parse_args()

#cfg_file = 'configs/256x192_adam_lr1e-3-res34_smpl_3d_cam_2x_mix.yaml'
cfg_file = 'configs/hybrik_transformer_hrnet_smpl24.yaml'
cfg = update_config(cfg_file)

dummy_set = edict({
    'joint_pairs_17': None,
    'joint_pairs_24': None,
    'joint_pairs_29': None,
    'bbox_3d_shape': (2.2, 2.2, 2.2)
})

transformation = SimpleTransform3DSMPL(
    dummy_set, scale_factor=cfg.DATASET.SCALE_FACTOR,
    color_factor=cfg.DATASET.COLOR_FACTOR,
    occlusion=cfg.DATASET.OCCLUSION,
    input_size=cfg.MODEL.IMAGE_SIZE,
    output_size=cfg.MODEL.HEATMAP_SIZE,
    depth_dim=cfg.MODEL.EXTRA.DEPTH_DIM,
    bbox_3d_shape=(2.2, 2.2, 2.2),
    rot=cfg.DATASET.ROT_FACTOR, sigma=cfg.MODEL.EXTRA.SIGMA,
    train=False, add_dpg=False,
    loss_type=cfg.LOSS['TYPE'])

det_model = fasterrcnn_resnet50_fpn(pretrained=True)

hybrik_model = builder.build_sppe(cfg.MODEL)

print(f'Loading model from {CKPT}...')

#hybrik_model.load_state_dict(torch.load(opt.checkpoint, map_location='cpu'), strict=False)
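# Note: with strict=False, load_state_dict silently ignores any keys that do not
# match the model, so a config/checkpoint mismatch will not raise an error here;
# the affected layers simply keep their random initialization.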
hybrik_model.load_state_dict(torch.load(CKPT, map_location='cpu'), strict=False)

det_model.cuda(opt.gpu)
hybrik_model.cuda(opt.gpu)
det_model.eval()
hybrik_model.eval()

files = os.listdir(opt.img_dir)

if not os.path.exists(opt.out_dir):
    os.makedirs(opt.out_dir)

for file in tqdm(files):
    if not os.path.isdir(file) and file[-4:] in ['.jpg', '.png']:
        # is an image
        if file[:4] == 'res_':
            continue

        # process file name
        img_path = os.path.join(opt.img_dir, file)
        dirname = os.path.dirname(img_path)
        basename = os.path.basename(img_path)

        # Run Detection
        input_image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        det_input = det_transform(input_image).to(opt.gpu)
        det_output = det_model([det_input])[0]

        tight_bbox = get_one_box(det_output)  # xyxy

        # Run HybrIK
        #print(transformation.test_transform(img_path, tight_bbox))
        pose_input, bbox = transformation.test_transform(img_path, tight_bbox)
        pose_input = pose_input.to(opt.gpu)[None, :, :, :]

        pose_output = hybrik_model(pose_input)

        # Visualization
        image = input_image
        img_size = (image.shape[0], image.shape[1])
        focal = np.array([1000, 1000])
        bbox = xyxy2xywh(bbox)
        princpt = [bbox[0], bbox[1]]

        renderer = SMPLRenderer(faces=hybrik_model.smpl.faces,
                                img_size=img_size, focal=focal,
                                princpt=princpt)

        transl = pose_output.transl.detach().cpu().numpy().squeeze()
        transl[2] = transl[2] * 256 / bbox[2]

        image_vis = vis_smpl_3d(
            pose_output, image, cam_root=transl,
            f=focal, c=princpt, renderer=renderer)

        image_vis = cv2.cvtColor(image_vis, cv2.COLOR_RGB2BGR)

        res_path = os.path.join(opt.out_dir, basename)
        cv2.imwrite(res_path, image_vis)

It would be great if you could help with this.