Open jameskuma opened 4 months ago
I tried to use the TalkSHOW code to visualize the data, but I get a bad result. Do you know the reason? My code is as follows (from TalkSHOW/scripts/demo.py):
import os

import cv2
import numpy as np
import pyrender
import smplx
import torch

# render_mesh_helper, SMPLOutput, zelin_log and device are assumed to be
# defined elsewhere in the original script.

lower_pose = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
-0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
lower_pose_stand = torch.tensor([
8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
3.0747, -0.0158, -0.0152,
-3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
-3.9716e-01, -4.0229e-02, -1.2637e-01,
7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
7.8632e-01, -4.3810e-02, 1.4375e-02,
-1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
def part2full(input, stand=False):
if stand:
lp = torch.zeros_like(lower_pose)
lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
else:
lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
input = torch.cat([input[:, :3],
lp[:, :15],
input[:, 3:6],
lp[:, 15:21],
input[:, 6:9],
lp[:, 21:27],
input[:, 9:12],
lp[:, 27:],
input[:, 12:]]
, dim=1)
return input
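part2full splices the 33 fixed lower-body channels (lower_pose above) into the upper-body prediction, turning a [n_frames, 232] vector into the full [n_frames, 265] SMPL-X parameter vector. A quick shape-only sanity check with a dummy input:

dummy = torch.zeros(10, 232)
assert part2full(dummy, stand=True).shape == (10, 265)
assert part2full(dummy, stand=False).shape == (10, 265)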
def main():
# * create smplx model
zelin_log.info('init smplx model...')
dtype = torch.float64
smplx_path = './visualise/'
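# NOTE: smplx.create(model_path=..., model_type='smplx') resolves the model file
# under <model_path>/smplx/, so this expects ./visualise/smplx/SMPLX_NEUTRAL.npz
# (the same file loaded explicitly for the mesh faces further below).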
model_params = dict(model_path=smplx_path,
model_type='smplx',
create_global_orient=True,
create_body_pose=True,
create_betas=True,
num_betas=300,
create_left_hand_pose=True,
create_right_hand_pose=True,
use_pca=False,
flat_hand_mean=False,
create_expression=True,
num_expression_coeffs=100,
num_pca_comps=12,
create_jaw_pose=True,
create_leye_pose=True,
create_reye_pose=True,
create_transl=False,
dtype=dtype,
)
smplx_model = smplx.create(**model_params).to(device)
# * load smplx param
# this is DiffSHEG output
pred_smplx = np.load('results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_1/Forrest_tts.npy')
pred_smplx = torch.from_numpy(pred_smplx).float().to(device)[0][:100]
pred_smplx = part2full(pred_smplx, stand=True)
# * pred_smplx size: [n_frames, param_dim]
import tqdm
vertices = []
betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
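# betas has 300 shape components to match num_betas=300 in model_params above;
# all-zero betas give the template (mean) body shape.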
for frame_ind in tqdm.tqdm(range(pred_smplx.shape[0]), desc='infer mesh vertices per frame'):
sample_output: SMPLOutput = smplx_model.forward(
betas=betas,
jaw_pose=pred_smplx[frame_ind][0:3].unsqueeze_(dim=0),
leye_pose=pred_smplx[frame_ind][3:6].unsqueeze_(dim=0),
reye_pose=pred_smplx[frame_ind][6:9].unsqueeze_(dim=0),
global_orient=pred_smplx[frame_ind][9:12].unsqueeze_(dim=0),
body_pose=pred_smplx[frame_ind][12:75].unsqueeze_(dim=0),
left_hand_pose=pred_smplx[frame_ind][75:120].unsqueeze_(dim=0),
right_hand_pose=pred_smplx[frame_ind][120:165].unsqueeze_(dim=0),
expression=pred_smplx[frame_ind][165:265].unsqueeze_(dim=0),
return_verts=True,
)
vertices.append(sample_output.vertices.detach().cpu().numpy().squeeze())
vertices = np.asarray(vertices)
print(vertices.shape)
# * debug Render
exp_dir = 'exp/speech2smplx'
os.makedirs(exp_dir, exist_ok=True)
num_frames = vertices.shape[0]
# the dataset's coordinate frame is inverted; flip y and z back
vertices = vertices.reshape(vertices.shape[0], -1, 3)
vertices[:, :, 1] = -vertices[:, :, 1]
vertices[:, :, 2] = -vertices[:, :, 2]
width, height = 800, 1440
viewport_height = 1440
z_offset = 1.8
video_fname = 'demo'
os.makedirs(f'{exp_dir}/video_frames', exist_ok=True)
writer = cv2.VideoWriter(f'{exp_dir}/{video_fname}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height), True)
center = np.mean(vertices[0], axis=0)
render_helper = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=viewport_height)
class Struct(object):
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
path = os.path.join(os.getcwd(), 'visualise/smplx/SMPLX_NEUTRAL.npz')
model_data = np.load(path, allow_pickle=True)
data_struct = Struct(**model_data)
for i_frame in tqdm.tqdm(range(num_frames), desc='render debug image'):
vertice = vertices[i_frame]
# TODO: save vertice as npz
imgi = render_mesh_helper((vertice, data_struct.f), center, camera='o', r=render_helper, y=0.7, z_offset=z_offset)
imgi = imgi.astype(np.uint8)
# save image as frame
cv2.imwrite(f'{exp_dir}/video_frames/{i_frame:04d}.png', imgi)
# save image as video
writer.write(imgi)
writer.release()
if __name__ == '__main__':
main()
Owner
Hi James, you may want to pay attention to the code here: https://github.com/JeremyCJM/DiffSHEG/blob/3ebf3058f48cba3da9146afb7623e9ec1ab9e9a5/datasets/show.py#L146. The order of channels for the pose should be carefully aligned with the pose order in the visualization code of TalkSHOW.
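Judging from the concatenation used later in this thread (jaw, then body gesture, then expression), the 232 channels appear to be jaw pose (3) + upper-body pose (129) + expression (100). A minimal sketch of feeding such an array through TalkSHOW's helpers, under that assumption (the input path is hypothetical):

from data_utils.lower_body import part2full  # TalkSHOW helper
import numpy as np
import torch

pred = torch.from_numpy(np.load('Forrest_tts.npy')).squeeze().float()  # [n_frames, 232]
full = part2full(pred, stand=True)  # [n_frames, 265], fixed lower body spliced in
# After part2full, the channels line up with the SMPL-X forward calls in this thread:
# [0:3] jaw, [3:6] leye, [6:9] reye, [9:12] global_orient, [12:75] body_pose,
# [75:120] left_hand_pose, [120:165] right_hand_pose, [165:265] expression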
Thank you for the reply!
Yes, so what is the channel order of these output files? I read these .npy files and found that they are [n_frame, 232], where 232 is exactly the same dimension as the output of SHOW/TalkSHOW. The order is important since I need to feed them into this function to get the mesh:
pred_smplx = np.load('Forrest_tts.npy')
sample_output: SMPLOutput = smplx_model.forward(
betas=betas,
jaw_pose=pred_smplx[0][0:3].unsqueeze_(dim=0),
leye_pose=pred_smplx[0][3:6].unsqueeze_(dim=0),
reye_pose=pred_smplx[0][6:9].unsqueeze_(dim=0),
global_orient=pred_smplx[0][9:12].unsqueeze_(dim=0),
body_pose=pred_smplx[0][12:75].unsqueeze_(dim=0),
left_hand_pose=pred_smplx[0][75:120].unsqueeze_(dim=0),
right_hand_pose=pred_smplx[0][120:165].unsqueeze_(dim=0),
expression=pred_smplx[0][165:265].unsqueeze_(dim=0),
return_verts=True,
)
Hello, did you render the result correctly?
Hi @jameskuma, this is my code to visualize the SHOW results, which is modified from the visualization code in TalkSHOW. Remember to specify the face_path and gesture_path arguments.
import os
import sys
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
sys.path.append(os.getcwd())
from transformers import Wav2Vec2Processor
from glob import glob
import numpy as np
import json
import smplx as smpl
from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool
import time
def init_model(model_name, model_path, args, config):
if model_name == 's2g_face':
generator = s2g_face(
args,
config,
)
elif model_name == 's2g_body_vq':
generator = s2g_body_vq(
args,
config,
)
elif model_name == 's2g_body_pixel':
generator = s2g_body_pixel(
args,
config,
)
else:
raise NotImplementedError
model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
if model_name == 'smplx_S2G':
generator.generator.load_state_dict(model_ckpt['generator']['generator'])
elif 'generator' in list(model_ckpt.keys()):
generator.load_state_dict(model_ckpt['generator'])
else:
model_ckpt = {'generator': model_ckpt}
generator.load_state_dict(model_ckpt)
return generator
def init_dataloader(data_root, speakers, args, config):
if data_root.endswith('.csv'):
raise NotImplementedError
else:
data_class = torch_data
if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
data_base = torch_data(
data_root=data_root,
speakers=speakers,
split='test',
limbscaling=False,
normalization=config.Data.pose.normalization,
norm_method=config.Data.pose.norm_method,
split_trans_zero=False,
num_pre_frames=config.Data.pose.pre_pose_length,
num_generate_length=config.Data.pose.generate_length,
num_frames=30,
aud_feat_win_size=config.Data.aud.aud_feat_win_size,
aud_feat_dim=config.Data.aud.aud_feat_dim,
feat_method=config.Data.aud.feat_method,
smplx=True,
audio_sr=22000,
convert_to_6d=config.Data.pose.convert_to_6d,
expression=config.Data.pose.expression,
config=config
)
else:
data_base = torch_data(
data_root=data_root,
speakers=speakers,
split='val',
limbscaling=False,
normalization=config.Data.pose.normalization,
norm_method=config.Data.pose.norm_method,
split_trans_zero=False,
num_pre_frames=config.Data.pose.pre_pose_length,
aud_feat_win_size=config.Data.aud.aud_feat_win_size,
aud_feat_dim=config.Data.aud.aud_feat_dim,
feat_method=config.Data.aud.feat_method
)
if config.Data.pose.normalization:
norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
norm_stats = np.load(norm_stats_fn, allow_pickle=True)
data_base.data_mean = norm_stats[0]
data_base.data_std = norm_stats[1]
else:
norm_stats = None
data_base.get_dataset()
infer_set = data_base.all_dataset
infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)
return infer_set, infer_loader, norm_stats
def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
vertices_list = []
poses_list = []
expression = torch.zeros([1, 50])
for i in result_list:
vertices = []
poses = []
for j in range(i.shape[0]):
output = smplx_model(betas=betas,
expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
jaw_pose=i[j][0:3].unsqueeze_(dim=0),
leye_pose=i[j][3:6].unsqueeze_(dim=0),
reye_pose=i[j][6:9].unsqueeze_(dim=0),
global_orient=i[j][9:12].unsqueeze_(dim=0),
body_pose=i[j][12:75].unsqueeze_(dim=0),
left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
return_verts=True)
vertices.append(output.vertices.detach().cpu().numpy().squeeze())
# pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
pose = output.body_pose
poses.append(pose.detach().cpu())
vertices = np.asarray(vertices)
vertices_list.append(vertices)
poses = torch.cat(poses, dim=0)
poses_list.append(poses)
if require_pose:
return vertices_list, poses_list
else:
return vertices_list, None
global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
smplx_model, rendertool, args=None, config=None, face_path=None, gesture_path=None):
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
am_sr = 16000
num_sample = 1
face = False
if face:
body_static = torch.zeros([1, 162], device='cuda')
body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
stand = False
j = 0
gt_0 = None
face_list = os.listdir(face_path)
face_list.sort()
gesture_list = os.listdir(gesture_path)
gesture_list.sort()
for idx, bat in enumerate(infer_loader):
poses_ = bat['poses'].to(torch.float32).to(device)
if poses_.shape[-1] == 300:
# import pdb; pdb.set_trace()
j = j + 1
if j > 1000:
continue
id = bat['speaker'].to('cuda') - 20
if config.Data.pose.expression:
expression = bat['expression'].to(device).to(torch.float32)
poses = torch.cat([poses_, expression], dim=1)
else:
poses = poses_
cur_wav_file = bat['aud_file'][0]
npy_file_name = 'visualise/video/' + config.Log.name + '/' + \
cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1] + '.npy'
if os.path.exists(npy_file_name):
continue
betas = bat['betas'][0].to(torch.float64).to('cuda')
# betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
gt = poses.to('cuda').squeeze().transpose(1, 0)
if config.Data.pose.normalization: # false
gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
if config.Data.pose.convert_to_6d: # false
if config.Data.pose.expression:
gt_exp = gt[:, -100:]
gt = gt[:, :-100]
gt = gt.reshape(gt.shape[0], -1, 6)
gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
gt = torch.cat([gt, gt_exp], -1)
if face: # false
gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
result_list = [gt]
# cur_wav_file = '.\\training_data\\1_song_(Vocals).wav'
############################ Prediction ############################
pred_face = np.load(os.path.join(face_path, face_list[idx]))
pred_face = torch.tensor(pred_face).squeeze().to('cuda')
pred_jaw = pred_face[:, :3]
pred_face = pred_face[:, 3:]
for i in range(num_sample):
pred_res = np.load(os.path.join(gesture_path, gesture_list[idx]))
pred = torch.tensor(pred_res).squeeze().to('cuda')
if pred.shape[0] < pred_face.shape[0]:
repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
pred = torch.cat([pred, repeat_frame], dim=0)
else:
pred = pred[:pred_face.shape[0], :]
# pred = torch.cat([pred, pred_face], dim=-1)
pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
pred = part2full(pred, stand)
result_list.append(pred)
vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
result_list = [res.to('cpu') for res in result_list]
result_arr = np.concatenate(result_list[1:], axis=0)  # renamed from `dict` to avoid shadowing the builtin
file_name = 'visualise/video/' + config.Log.name + '/' + \
cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
np.save(file_name, result_arr)
rendertool._render_sequences(cur_wav_file, vertices_list[1:], stand=stand, face=face)
def main():
parser = parse_args()
args = parser.parse_args()
device = torch.device(args.gpu)
torch.cuda.set_device(device)
config = load_JsonConfig(args.config_file)
face_model_name = args.face_model_name
face_model_path = args.face_model_path
body_model_name = args.body_model_name
body_model_path = args.body_model_path
smplx_path = './visualise/'
os.environ['smplx_npz_path'] = config.smplx_npz_path
os.environ['extra_joint_path'] = config.extra_joint_path
os.environ['j14_regressor_path'] = config.j14_regressor_path
print('init model...')
generator = init_model(body_model_name, body_model_path, args, config)
generator2 = None
generator_face = init_model(face_model_name, face_model_path, args, config)
print('init dataloader...')
infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
print('init smplx model...')
dtype = torch.float64
model_params = dict(model_path=smplx_path,
model_type='smplx',
create_global_orient=True,
create_body_pose=True,
create_betas=True,
num_betas=300,
create_left_hand_pose=True,
create_right_hand_pose=True,
use_pca=False,
flat_hand_mean=False,
create_expression=True,
num_expression_coeffs=100,
num_pca_comps=12,
create_jaw_pose=True,
create_leye_pose=True,
create_reye_pose=True,
create_transl=False,
# gender='ne',
dtype=dtype, )
smplx_model = smpl.create(**model_params).to('cuda')
if args.rename is not None:
config.Log.name = args.rename
print('init rendertool...')
rendertool = RenderTool('visualise/video/' + config.Log.name)
infer(config.Data.data_root, generator, generator_face, generator2, args.exp_name, infer_loader, infer_set, device,
norm_stats, True, smplx_model, rendertool, args, config, face_path=args.face_path, gesture_path=args.gesture_path)
if __name__ == '__main__':
main()
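For reference, an invocation sketch for this script; the flag names mirror the attributes it reads from parse_args, and every path and checkpoint name is a placeholder:

python render_show_results.py \
    --config_file ./config/body_pixel.json \
    --body_model_name s2g_body_pixel --body_model_path <body_ckpt>.pth \
    --face_model_name s2g_face --face_model_path <face_ckpt>.pth \
    --gpu 0 --speakers oliver \
    --face_path <dir_of_face_npys> --gesture_path <dir_of_gesture_npys>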
Hello, I get similar results to @jameskuma.
I tried to check whether there is a mismatch between the DiffSHEG output parameters and the SHOW SMPL-X model input, but everything seems okay. Has anyone found the right way to render the SHOW results?
@JeremyCJM I tried running your code, but I cannot figure out what face_path and gesture_path are, since the DiffSHEG model only produces one npy output. Also, I am not quite sure why it creates a dataset and loader for the whole TalkSHOW dataset while inferring a single output. Can you help me use your code for a single inference from the .npy output DiffSHEG gives?
Any help in visualising would be appreciated!
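From the owner's script above, face_path and gesture_path appear to be directories of per-clip .npy files: each face file holds jaw + expression (split at column 3, i.e. [n, 103]) and each gesture file the 129-dim upper-body pose. A sketch of splitting a single DiffSHEG [n_frames, 232] output under that assumed layout (paths and directory names hypothetical):

import numpy as np

arr = np.load('Forrest_tts.npy').squeeze()                 # [n_frames, 232]
face = np.concatenate([arr[:, :3], arr[:, 132:]], axis=1)  # jaw (3) + expression (100) -> [n, 103]
gesture = arr[:, 3:132]                                    # 129-dim upper-body pose
np.save('face_npys/Forrest_tts.npy', face)
np.save('gesture_npys/Forrest_tts.npy', gesture)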
https://github.com/user-attachments/assets/08c0cd0e-f1ed-43eb-94a5-1ed136528e32
Here is my code:
import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append(os.getcwd())
from transformers import Wav2Vec2Processor
from glob import glob
import numpy as np
import json
import smplx as smpl
from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool
global device
device = 'cpu'
def init_model(model_name, model_path, args, config):
if model_name == 's2g_face':
generator = s2g_face(
args,
config,
)
elif model_name == 's2g_body_vq':
generator = s2g_body_vq(
args,
config,
)
elif model_name == 's2g_body_pixel':
generator = s2g_body_pixel(
args,
config,
)
elif model_name == 's2g_LS3DCG':
generator = LS3DCG(
args,
config,
)
else:
raise NotImplementedError
model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
if model_name == 'smplx_S2G':
generator.generator.load_state_dict(model_ckpt['generator']['generator'])
elif 'generator' in list(model_ckpt.keys()):
generator.load_state_dict(model_ckpt['generator'])
else:
model_ckpt = {'generator': model_ckpt}
generator.load_state_dict(model_ckpt)
return generator
def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
vertices_list = []
poses_list = []
expression = torch.zeros([1, 50])
for i in result_list:
vertices = []
poses = []
for j in range(i.shape[0]):
output = smplx_model(betas=betas,
expression=i[j][165:265].unsqueeze_(dim=0),
jaw_pose=i[j][0:3].unsqueeze_(dim=0),
leye_pose=i[j][3:6].unsqueeze_(dim=0),
reye_pose=i[j][6:9].unsqueeze_(dim=0),
global_orient=i[j][9:12].unsqueeze_(dim=0),
body_pose=i[j][12:75].unsqueeze_(dim=0),
left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
return_verts=True)
vertices.append(output.vertices.detach().cpu().numpy().squeeze())
# pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
pose = output.body_pose
poses.append(pose.detach().cpu())
vertices = np.asarray(vertices)
vertices_list.append(vertices)
poses = torch.cat(poses, dim=0)
poses_list.append(poses)
if require_pose:
return vertices_list, poses_list
else:
return vertices_list, None
global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
def infer(g_body, g_face, smplx_model, rendertool, config, args):
betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
am_sr = 16000
num_sample = args.num_sample
cur_wav_file = args.audio_file
id = args.id
face = args.only_face
stand = args.stand
if face:
body_static = torch.zeros([1, 162], device=device)
body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
# result_list = []
# pred_face = g_face.infer_on_audio(cur_wav_file,
# initial_pose=None,
# norm_stats=None,
# w_pre=False,
# # id=id,
# frame=None,
# am=am,
# am_sr=am_sr
# )
# pred_face = torch.tensor(pred_face).squeeze().to(device)
# # pred_face = torch.zeros([gt.shape[0], 105])
# if config.Data.pose.convert_to_6d:
# pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
# pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
# pred_face = pred_face[:, 6:]
# else:
# pred_jaw = pred_face[:, :3]
# pred_face = pred_face[:, 3:]
# id = torch.tensor([id], device=device)
# for i in range(num_sample):
# pred_res = g_body.infer_on_audio(cur_wav_file,
# initial_pose=None,
# norm_stats=None,
# txgfile=None,
# id=id,
# var=None,
# fps=30,
# w_pre=False
# )
# pred = torch.tensor(pred_res).squeeze().to(device)
# if pred.shape[0] < pred_face.shape[0]:
# repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
# pred = torch.cat([pred, repeat_frame], dim=0)
# else:
# pred = pred[:pred_face.shape[0], :]
# body_or_face = False
# if pred.shape[1] < 275:
# body_or_face = True
# if config.Data.pose.convert_to_6d:
# pred = pred.reshape(pred.shape[0], -1, 6)
# pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
# pred = pred.reshape(pred.shape[0], -1)
# if config.Model.model_name == 's2g_LS3DCG':
# pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
# else:
# pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
# # pred[:, 9:12] = global_orient
# pred = part2full(pred, stand)
# if face:
# pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
# # result_list[0] = poses2pred(result_list[0], stand)
# # if gt_0 is None:
# # gt_0 = gt
# # pred = pred2poses(pred, gt_0)
# # result_list[0] = poses2poses(result_list[0], gt_0)
# result_list.append(pred)
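# NOTE: the DiffSHEG output path below is hardcoded; point it at your own
# results directory. The loaded array is expected to be [1, n_frames, 232].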
result_list = torch.from_numpy(np.load('../DiffSHEG/results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_4/gesture/Forrest_tts.npy'))
result_list = part2full(result_list[0], stand=True).unsqueeze(0)
print(result_list.shape)
vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
result_list = [res.to('cpu') for res in result_list]
result_arr = np.concatenate(result_list[:], axis=0)  # renamed from `dict` to avoid shadowing the builtin
file_name = 'visualise/video/' + config.Log.name + '/' + \
cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
np.save(file_name, result_arr)
rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
def main():
parser = parse_args()
args = parser.parse_args()
# device = torch.device(args.gpu)
# torch.cuda.set_device(device)
config = load_JsonConfig(args.config_file)
face_model_name = args.face_model_name
face_model_path = args.face_model_path
body_model_name = args.body_model_name
body_model_path = args.body_model_path
smplx_path = './visualise/'
os.environ['smplx_npz_path'] = config.smplx_npz_path
os.environ['extra_joint_path'] = config.extra_joint_path
os.environ['j14_regressor_path'] = config.j14_regressor_path
print('init model...')
generator = init_model(body_model_name, body_model_path, args, config)
generator2 = None
generator_face = init_model(face_model_name, face_model_path, args, config)
print('init smplx model...')
dtype = torch.float64
model_params = dict(model_path=smplx_path,
model_type='smplx',
create_global_orient=True,
create_body_pose=True,
create_betas=True,
num_betas=300,
create_left_hand_pose=True,
create_right_hand_pose=True,
use_pca=False,
flat_hand_mean=False,
create_expression=True,
num_expression_coeffs=100,
num_pca_comps=12,
create_jaw_pose=True,
create_leye_pose=True,
create_reye_pose=True,
create_transl=False,
# gender='ne',
dtype=dtype, )
smplx_model = smpl.create(**model_params).to(device)
print('init rendertool...')
rendertool = RenderTool('visualise/video/' + config.Log.name)
infer(generator, generator_face, smplx_model, rendertool, config, args)
if __name__ == '__main__':
main()
Dear author,
Thank you for this awesome work!
I ran the inference part of this repo using the SHOW dataset, and I only get a bunch of .npz files. However, how do I visualize them with the visualization tool in TalkSHOW? I mean, which part of the code should I use to visualize the results?
Best regards