Great work! I have a question about the visualization results.

lisaner000 commented 6 months ago

Through the demo I can get the result shown in the third line of the project page, which has a black background. What should I do to get the result shown in the second line, which keeps the original background?

yw0208 commented 6 months ago

Sorry, I didn't get your point. With my default settings, you should directly get the result of the second line, with the original background.

yw0208 commented 6 months ago

Specifically, if you want to control the background, you can refer to the code: the background is changed through the parameter "img".

lisaner000 commented 6 months ago

yw0208 commented 6 months ago

Did you change the code in demo.py? If you didn't change anything, you should get the image with the original background directly. Please show me your demo.py.

lisaner000 commented 6 months ago

I just added "os.environ['PYOPENGL_PLATFORM'] = 'osmesa'" at the beginning of demo.py. The following is my demo.py.

# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

# Module setup: imports, OpenGL platform selection, constants and RNG seeding.
# One statement per line; the statement order of the original listing is kept.
import os
# Select the PyOpenGL backend right after `import os`, before the imports below run.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
import os.path as osp
from lib.core.config import BASE_DATA_DIR
from lib.models.staf import STAF
from lib.models.smpl import SMPL, SMPL_MODEL_DIR

# NOTE(review): this second assignment replaces the 'osmesa' value set above for
# everything imported after this point. It may have been a commented-out line in
# the original file -- confirm which backend is intended and keep only one.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import cv2
import time
import torch
import joblib
import shutil
import colorsys
import argparse
import random
import numpy as np
from pathlib import Path
from tqdm import tqdm
from multi_person_tracker import MPT
from torch.utils.data import DataLoader

from lib.utils.renderer import Renderer
from lib.dataset._dataset_demo import CropDataset, FeatureDataset
from lib.utils.demo_utils import (
    download_youtube_clip,
    convert_crop_cam_to_orig_img,
    prepare_rendering_results,
    video_to_images,
    images_to_video,
)

# Tracklets with fewer frames than this are discarded in main().
MIN_NUM_FRAMES = 25

# Fix every RNG seed so repeated runs draw the same random values
# (main() picks mesh colours with np.random.rand()).
random.seed(1)
torch.manual_seed(1)
np.random.seed(1)

# Run the full demo pipeline for one video: extract frames, track people, run STAF on
# every tracklet, optionally dump the results, then render and save the output videos.
# `args` is the argparse namespace built by the script's entry point.
# NOTE(review): this listing is incomplete -- several calls are cut off after their
# opening parenthesis (e.g. `MPT(`, `STAF(`, `SMPL(`, `CropDataset(`) and some
# statements appear to be missing, so the function is not runnable as shown.
def main(args): device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

""" Prepare input video (images) """
video_file = args.vid_file
# YouTube links are downloaded to /tmp first; anything else must be an existing file.
if video_file.startswith('https://www.youtube.com'):
    print(f"Donwloading YouTube video \'{video_file}\'")
    video_file = download_youtube_clip(video_file, '/tmp')
    if video_file is None:
        exit('Youtube url is not valid!')
    print(f"YouTube Video has been downloaded to {video_file}...")

if not os.path.isfile(video_file):
    exit(f"Input video \'{video_file}\' does not exist!")

# Results go to ./output/demo_output/<video name without '.mp4'>.
output_path = osp.join('./output/demo_output', os.path.basename(video_file).replace('.mp4', ''))
Path(output_path).mkdir(parents=True, exist_ok=True)
image_folder, num_frames, img_shape = video_to_images(video_file, return_info=True)

print(f"Input video number of frames {num_frames}\n")
orig_height, orig_width = img_shape[:2]

""" Run tracking """
total_time = time.time()
bbox_scale = 1.1  #
# run multi object tracker
# NOTE(review): the argument list and closing parenthesis of MPT( are missing here.
mot = MPT(
tracking_results = mot(image_folder)

# remove tracklets if num_frames is less than MIN_NUM_FRAMES
for person_id in list(tracking_results.keys()):
    if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
        del tracking_results[person_id]

""" Get STAF model """
seq_len = 9
# NOTE(review): the arguments of STAF( are missing here.
model = STAF(

# Load pretrained weights
pretrained_file = args.model
ckpt = torch.load(pretrained_file)
print(f"Load pretrained weights from \'{pretrained_file}\'")
ckpt = ckpt['gen_state_dict']
# strict=False: keys that do not match between checkpoint and model are tolerated.
model.load_state_dict(ckpt, strict=False)

# Change mesh gender
gender = args.gender  # 'neutral', 'male', 'female'
# NOTE(review): the arguments of SMPL( are missing here.
model.regressor.smpl = SMPL(


# Get feature_extractor
from lib.models.pymaf_resnet import PyMAF_ResNet
hmr = PyMAF_ResNet().to(device)
checkpoint = torch.load('data/base_data/base_model.pt')
hmr.load_state_dict(checkpoint['model'], strict=False)

""" Run STAF on each person """
print("\nRunning STAF on each person tracklet...")
running_time = time.time()
running_results = {}
for person_id in tqdm(list(tracking_results.keys())):
    # joints2d is never reassigned below, so has_keypoints is always False in this listing.
    bboxes = joints2d = None
    bboxes = tracking_results[person_id]['bbox']
    frames = tracking_results[person_id]['frames']
    # Prepare static image features
    # NOTE(review): the arguments of CropDataset( are missing here.
    dataset = CropDataset(

    # Continue with the dataset's own bboxes/frames rather than the tracker's originals.
    bboxes = dataset.bboxes
    frames = dataset.frames
    has_keypoints = True if joints2d is not None else False

    crop_dataloader = DataLoader(dataset, batch_size=256, num_workers=0)  # 16

    with torch.no_grad():
        feature_list = []
        for i, batch in enumerate(crop_dataloader):
            if has_keypoints:
                batch, nj2d = batch
                norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

            batch = batch.to(device)
            # NOTE(review): `feature` is never added to feature_list in this listing;
            # the append appears to be missing.
            feature = hmr(batch.reshape(-1, 3, 224, 224))

        del batch

        feature_list = torch.cat(feature_list, dim=0)

    # Encode temporal features and estimate 3D human mesh
    # NOTE(review): the arguments of FeatureDataset( are missing here.
    dataset = FeatureDataset(
    dataset.feature_list = feature_list

    dataloader = DataLoader(dataset, batch_size=16, num_workers=0)  # 32
    with torch.no_grad():
        pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

        for i, batch in enumerate(dataloader):
            if has_keypoints:
                batch, nj2d = batch
                norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

            batch = batch.to(device)
            output = model(batch)[0][-1]

            # theta is split column-wise: [0:3] camera, [3:75] pose, [75:] betas.
            # NOTE(review): nothing is appended to pred_verts / pred_joints3d in this
            # listing, although both are concatenated below.
            pred_cam.append(output['theta'][:, :3])
            pred_pose.append(output['theta'][:, 3:75])
            pred_betas.append(output['theta'][:, 75:])

        pred_cam = torch.cat(pred_cam, dim=0)
        pred_verts = torch.cat(pred_verts, dim=0)
        pred_pose = torch.cat(pred_pose, dim=0)
        pred_betas = torch.cat(pred_betas, dim=0)
        pred_joints3d = torch.cat(pred_joints3d, dim=0)

        del batch

    # ========= Save results to a pickle file ========= #
    pred_cam = pred_cam.cpu().numpy()
    pred_verts = pred_verts.cpu().numpy()
    pred_pose = pred_pose.cpu().numpy()
    pred_betas = pred_betas.cpu().numpy()
    pred_joints3d = pred_joints3d.cpu().numpy()
    # Enlarge the bbox columns from index 2 onward by 20% in place
    # (presumably width/height -- confirm against the tracker's bbox layout).
    bboxes[:, 2:] = bboxes[:, 2:] * 1.2
    # if args.render_plain:
    #     pred_cam[:, 0], pred_cam[:, 1:] = 1, 0  # np.array([[1, 0, 0]])
    # NOTE(review): the arguments of convert_crop_cam_to_orig_img( are missing here.
    orig_cam = convert_crop_cam_to_orig_img(

    # Per-person result record; NOTE(review): the closing brace is missing in this listing.
    output_dict = {
        'pred_cam': pred_cam,
        'orig_cam': orig_cam,
        'verts': pred_verts,
        'pose': pred_pose,
        'betas': pred_betas,
        'joints3d': pred_joints3d,
        'joints2d': joints2d,
        'bboxes': bboxes,
        'frame_ids': frames,

    running_results[person_id] = output_dict

del model

end = time.time()
# STAF FPS counts from running_time (after model loading); total FPS counts from total_time.
fps = num_frames / (end - running_time)
print(f'STAF FPS: {fps:.2f}')
total_time = time.time() - total_time
print(f'Total time spent: {total_time:.2f} seconds (including model loading time).')
print(f'Total FPS (including model loading time): {num_frames / total_time:.2f}.')

if args.save_pkl:
    print(f"Saving output results to \'{os.path.join(output_path, 'staf_output.pkl')}\'.")
    joblib.dump(running_results, os.path.join(output_path, "staf_output.pkl"))

""" Render results as a single video """
renderer = Renderer(resolution=(orig_width, orig_height), orig_img=True, wireframe=args.wireframe)

# Rendered frames and untouched input frames are written to sibling folders of image_folder.
output_img_folder = f'{image_folder}_output'
input_img_folder = f'{image_folder}_input'
os.makedirs(output_img_folder, exist_ok=True)
os.makedirs(input_img_folder, exist_ok=True)

print(f"\nRendering output video, writing frames to {output_img_folder}")
# prepare results for rendering
frame_results = prepare_rendering_results(running_results, num_frames)
# One random hue per person, with fixed saturation 0.5 and value 1.0.
mesh_color = {k: colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) for k in running_results.keys()}

# NOTE(review): the closing `])` of this sorted([...]) expression is missing here.
image_file_names = sorted([
    os.path.join(image_folder, x)
    for x in os.listdir(image_folder)
    if x.endswith('.png') or x.endswith('.jpg')

for frame_idx in tqdm(range(len(image_file_names))):
    img_fname = image_file_names[frame_idx]
    img = cv2.imread(img_fname)
    # input_img keeps the untouched frame; img is the canvas that gets rendered on.
    input_img = img.copy()
    # --render_plain zeroes the canvas, i.e. meshes end up on a black background.
    if args.render_plain:
        img[:] = 0

    if args.sideview:
        side_img = np.zeros_like(img)

    for person_id, person_data in frame_results[frame_idx].items():
        frame_verts = person_data['verts']
        frame_cam = person_data['cam']

        mesh_filename = None
        if args.save_obj:
            mesh_folder = os.path.join(output_path, 'meshes', f'{person_id:04d}')
            Path(mesh_folder).mkdir(parents=True, exist_ok=True)
            mesh_filename = os.path.join(mesh_folder, f'{frame_idx:06d}.obj')

        mc = mesh_color[person_id]

        # NOTE(review): the arguments of both renderer.render( calls are cut off here.
        img = renderer.render(
        if args.sideview:
            side_img = renderer.render(
                axis=[0, 1, 0],

    # Side view, when requested, is placed to the right of the main view.
    if args.sideview:
        img = np.concatenate([img, side_img], axis=1)

    # save output frames
    cv2.imwrite(os.path.join(output_img_folder, f'{frame_idx:06d}.jpg'), img)
    cv2.imwrite(os.path.join(input_img_folder, f'{frame_idx:06d}.jpg'), input_img)

    if args.display:
        cv2.imshow('Video', img)
        # NOTE(review): the body of this `if` is missing (presumably a `break` on 'q').
        if cv2.waitKey(1) & 0xFF == ord('q'):

# NOTE(review): the body of this `if` is missing as well.
if args.display:

""" Save rendered video """
vid_name = os.path.basename(video_file)
save_name = f'MPS-Net_{vid_name.replace(".mp4", "")}_output.mp4'
save_path = os.path.join(output_path, save_name)

# Two videos are written: the rendered result and a re-encoded copy of the input frames.
images_to_video(img_folder=output_img_folder, output_vid_file=save_path)
images_to_video(img_folder=input_img_folder, output_vid_file=os.path.join(output_path, vid_name))
print(f"Saving result video to {os.path.abspath(save_path)}")

if __name__ == '__main__':
    # Command-line entry point: parse the demo options, pick the GPU, run main().
    parser = argparse.ArgumentParser()

    parser.add_argument('--vid_file', type=str, default='demo_video.mp4',
                        help='input video path or youtube link')

    parser.add_argument('--model', type=str, default='./data/base_data/3dpw_model_best.pth.tar',
                        help='path to pretrained model weight')

    parser.add_argument('--detector', type=str, default='yolo', choices=['yolo', 'maskrcnn'],
                        help='object detector to be used for bbox tracking')

    parser.add_argument('--yolo_img_size', type=int, default=416,
                        help='input image size for yolo detector')

    parser.add_argument('--tracker_batch_size', type=int, default=12,
                        help='batch size of object detector used for bbox tracking')

    parser.add_argument('--display', action='store_true',
                        help='visualize the results of each step during demo')

    parser.add_argument('--save_pkl', action='store_true',
                        help='save results to a pkl file')

    parser.add_argument('--save_obj', action='store_true',
                        help='save results as .obj files.')

    parser.add_argument('--gender', type=str, default='neutral',
                        help='set gender of people from (neutral, male, female)')

    parser.add_argument('--wireframe', action='store_true',
                        help='render all meshes as wireframes.')

    parser.add_argument('--sideview', action='store_true',
                        help='render meshes from alternate viewpoint.')

    parser.add_argument('--render_plain', action='store_true',
                        help='render meshes on plain background')

    # The default is an int to match type=int (it was the string '1').
    parser.add_argument('--gpu', type=int, default=1, help='gpu num')

    args = parser.parse_args()
    # Restrict the visible GPUs before main() first queries CUDA availability.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Run the demo pipeline with the parsed options.
    main(args)

yw0208 commented 6 months ago

Can you delete "os.environ['PYOPENGL_PLATFORM'] = 'osmesa'"? I believe it may be the reason.

lisaner000 commented 6 months ago

Oh, so you are running on a server? Maybe you can try this instead, i.e., os.environ['PYOPENGL_PLATFORM'] = 'egl'