JeremyCJM / DiffSHEG

[CVPR'24] DiffSHEG: A Diffusion-Based Approach for Real-Time Speech-driven Holistic 3D Expression and Gesture Generation
https://jeremycjm.github.io/proj/DiffSHEG/
BSD 3-Clause "New" or "Revised" License

[SHOW Visualization] Which part of code to refer #10

Open jameskuma opened 4 months ago

jameskuma commented 4 months ago

Dear author,

Thank you for this awesome work!

I ran the inference part of this repo on the SHOW dataset, and I only get a bunch of .npz files.

However, how do I visualize them with the visualization tool in TalkSHOW? I mean, which part of the code should I use to visualize the results?

Best regards

jameskuma commented 4 months ago

I tried to use the TalkSHOW code to visualize the data, but I get a bad result.

[attached image: the rendered mesh looks wrong]

Do you know the reason? My code is as follows (adapted from TalkSHOW/scripts/demo.py):


import os

import cv2
import numpy as np
import pyrender
import smplx
import torch
from smplx.utils import SMPLOutput  # available in recent smplx releases

# render_mesh_helper is a local rendering helper (not shown here)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lower_pose = torch.tensor(
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
     0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
     -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
lower_pose_stand = torch.tensor([
    8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
    3.0747, -0.0158, -0.0152,
    -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
    -3.9716e-01, -4.0229e-02, -1.2637e-01,
    7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
    7.8632e-01, -4.3810e-02, 1.4375e-02,
    -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])

def part2full(input, stand=False):
    if stand:
        lp = torch.zeros_like(lower_pose)
        lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
        lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
    else:
        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)

    input = torch.cat([input[:, :3],
                       lp[:, :15],
                       input[:, 3:6],
                       lp[:, 15:21],
                       input[:, 6:9],
                       lp[:, 21:27],
                       input[:, 9:12],
                       lp[:, 27:],
                       input[:, 12:]]
                      , dim=1)
    return input

def main():
    # * create smplx model
    print('init smplx model...')
    dtype = torch.float64
    smplx_path = './visualise/'
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        dtype=dtype,
    )
    smplx_model = smplx.create(**model_params).to(device)
    # * load smplx param
    # this is DiffSHEG output
    pred_smplx = np.load('results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_1/Forrest_tts.npy')
    pred_smplx = torch.from_numpy(pred_smplx).float().to(device)[0][:100]
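    # part2full (defined above) splices the fixed lower_pose channels (eyes, global
    # orientation, lower body) into the predicted parameters, producing the full
    # per-frame vector that is sliced into SMPL-X arguments below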
    pred_smplx = part2full(pred_smplx, stand=True)

    # * pred_smplx size: [n_frames, param_dim]
    import tqdm
    vertices = []
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    for frame_ind in tqdm.tqdm(range(pred_smplx.shape[0]), desc='infer mesh vertices per frame'):
        sample_output: SMPLOutput = smplx_model.forward(
            betas=betas,
            jaw_pose=pred_smplx[frame_ind][0:3].unsqueeze_(dim=0),
            leye_pose=pred_smplx[frame_ind][3:6].unsqueeze_(dim=0),
            reye_pose=pred_smplx[frame_ind][6:9].unsqueeze_(dim=0),
            global_orient=pred_smplx[frame_ind][9:12].unsqueeze_(dim=0),
            body_pose=pred_smplx[frame_ind][12:75].unsqueeze_(dim=0),
            left_hand_pose=pred_smplx[frame_ind][75:120].unsqueeze_(dim=0),
            right_hand_pose=pred_smplx[frame_ind][120:165].unsqueeze_(dim=0),
            expression=pred_smplx[frame_ind][165:265].unsqueeze_(dim=0),
            return_verts=True,
        )
        vertices.append(sample_output.vertices.detach().cpu().numpy().squeeze())
    vertices = np.asarray(vertices)

    print(vertices.shape)

    # * debug Render
    exp_dir = 'exp/speech2smplx'
    os.makedirs(exp_dir, exist_ok=True)
    num_frames = vertices.shape[0]

    # flip the y and z axes (the dataset convention is inverted relative to the renderer)
    vertices = vertices.reshape(vertices.shape[0], -1, 3)
    vertices[:, :, 1] = -vertices[:, :, 1]
    vertices[:, :, 2] = -vertices[:, :, 2]

    width, height = 800, 1440
    viewport_height = 1440
    z_offset = 1.8

    video_fname = 'demo'
    os.makedirs(f'{exp_dir}/video_frames', exist_ok=True)

    writer = cv2.VideoWriter(f'{exp_dir}/{video_fname}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height), True)
    center = np.mean(vertices[0], axis=0)

    render_helper = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=viewport_height)

    class Struct(object):
        def __init__(self, **kwargs):
            for key, val in kwargs.items():
                setattr(self, key, val)

    path = os.path.join(os.getcwd(), 'visualise/smplx/SMPLX_NEUTRAL.npz')
    model_data = np.load(path, allow_pickle=True)
    data_struct = Struct(**model_data)

    for i_frame in tqdm.tqdm(range(num_frames), desc='render debug image'):
        frame_verts = vertices[i_frame]
        # TODO: save frame_verts as npz
        imgi = render_mesh_helper((frame_verts, data_struct.f), center, camera='o', r=render_helper, y=0.7, z_offset=z_offset)
        imgi = imgi.astype(np.uint8)
        # save image as frame
        cv2.imwrite(f'{exp_dir}/video_frames/{i_frame:04d}.png', imgi)
        # save image as video
        writer.write(imgi)
    writer.release()

if __name__ == '__main__':
    main()
JeremyCJM commented 4 months ago

Hi James, you may want to pay attention to the code here: https://github.com/JeremyCJM/DiffSHEG/blob/3ebf3058f48cba3da9146afb7623e9ec1ab9e9a5/datasets/show.py#L146. The order of the pose channels should be carefully aligned with the pose layout in TalkSHOW's visualization code.
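As a purely illustrative sketch (the real channel ordering lives in datasets/show.py linked above, so the permutation below is a placeholder, not the actual mapping): if DiffSHEG stores the pose channels in a different order than the TalkSHOW visualizer expects, realigning them is a single index permutation over the last dimension.

import numpy as np

pred = np.load('Forrest_tts.npy')   # DiffSHEG output in its native channel order
perm = np.arange(pred.shape[-1])    # placeholder permutation; fill in from datasets/show.py
pred_talkshow = pred[..., perm]     # reordered to the layout the TalkSHOW visualizer expects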

jameskuma commented 4 months ago


Hi James, you may want to pay attention to the code here: https://github.com/JeremyCJM/DiffSHEG/blob/3ebf3058f48cba3da9146afb7623e9ec1ab9e9a5/datasets/show.py#L146. The order of the pose channels should be carefully aligned with the pose layout in TalkSHOW's visualization code.

Thank you for the reply!

Yes, so what is the order of the channels in these output files? I read these .npy files and found that they are [n_frame, 232], where 232 exactly matches the output dimension of SHOW/TalkSHOW.

The order is important since I need to feed them into this function to get the mesh:

pred_smplx = np.load('Forrest_tts.npy')
pred_smplx = torch.from_numpy(pred_smplx).float()  # convert to a tensor so the slices below support unsqueeze_

sample_output: SMPLOutput = smplx_model.forward(
    betas=betas,
    jaw_pose=pred_smplx[0][0:3].unsqueeze_(dim=0),
    leye_pose=pred_smplx[0][3:6].unsqueeze_(dim=0),
    reye_pose=pred_smplx[0][6:9].unsqueeze_(dim=0),
    global_orient=pred_smplx[0][9:12].unsqueeze_(dim=0),
    body_pose=pred_smplx[0][12:75].unsqueeze_(dim=0),
    left_hand_pose=pred_smplx[0][75:120].unsqueeze_(dim=0),
    right_hand_pose=pred_smplx[0][120:165].unsqueeze_(dim=0),
    expression=pred_smplx[0][165:265].unsqueeze_(dim=0),
    return_verts=True,
)
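For context, the channel arithmetic I am assuming: the slices above cover 3 (jaw) + 3 (leye) + 3 (reye) + 3 (global) + 63 (body) + 45 (left hand) + 45 (right hand) + 100 (expression) = 265 values per frame, while the DiffSHEG output has 232, so part2full presumably supplies the remaining 33 fixed eye / global-orientation / lower-body values (the lower_pose tensor in my script above has exactly 33 entries).

# quick sanity check on the channel counts
assert 3 + 3 + 3 + 3 + 63 + 45 + 45 + 100 == 265
assert 265 - 232 == 33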
Mumuwei commented 4 months ago


Hello, did you render the result correctly?

JeremyCJM commented 3 months ago

Hi @jameskuma, this is my code to visualize the SHOW results, which is modified from the visualization code in TalkSHOW. Remember to specify the face_path and gesture_path arguments.

import os
import sys

# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

import time

def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator

def init_dataloader(data_root, speakers, args, config):
    if data_root.endswith('.csv'):
        raise NotImplementedError
    else:
        data_class = torch_data
    if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='test',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            num_generate_length=config.Data.pose.generate_length,
            num_frames=30,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method,
            smplx=True,
            audio_sr=22000,
            convert_to_6d=config.Data.pose.convert_to_6d,
            expression=config.Data.pose.expression,
            config=config
        )
    else:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='val',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method
        )
    if config.Data.pose.normalization:
        norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
        norm_stats = np.load(norm_stats_fn, allow_pickle=True)
        data_base.data_mean = norm_stats[0]
        data_base.data_std = norm_stats[1]
    else:
        norm_stats = None

    data_base.get_dataset()
    infer_set = data_base.all_dataset
    infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)

    return infer_set, infer_loader, norm_stats

def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None

global_orient = torch.tensor([3.0747, -0.0158, -0.0152])

def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
          smplx_model, rendertool, args=None, config=None, face_path=None, gesture_path=None):
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = 1
    face = False
    if face:
        body_static = torch.zeros([1, 162], device='cuda')
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
    stand = False
    j = 0
    gt_0 = None

    face_list = os.listdir(face_path)
    face_list.sort()

    gesture_list = os.listdir(gesture_path)
    gesture_list.sort()

    for idx, bat in enumerate(infer_loader):
        poses_ = bat['poses'].to(torch.float32).to(device)
        if poses_.shape[-1] == 300:
            # import pdb; pdb.set_trace()
            j = j + 1
            if j > 1000:
                continue
            id = bat['speaker'].to('cuda') - 20
            if config.Data.pose.expression:
                expression = bat['expression'].to(device).to(torch.float32)
                poses = torch.cat([poses_, expression], dim=1)
            else:
                poses = poses_
            cur_wav_file = bat['aud_file'][0]
            npy_file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1] + '.npy'

            if os.path.exists(npy_file_name):
                continue

            betas = bat['betas'][0].to(torch.float64).to('cuda')
            # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
            gt = poses.to('cuda').squeeze().transpose(1, 0)
            if config.Data.pose.normalization: # false
                gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
            if config.Data.pose.convert_to_6d: # false
                if config.Data.pose.expression:
                    gt_exp = gt[:, -100:]
                    gt = gt[:, :-100]

                gt = gt.reshape(gt.shape[0], -1, 6)

                gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
                gt = torch.cat([gt, gt_exp], -1)
            if face: # false
                gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)

            result_list = [gt]

            # cur_wav_file = '.\\training_data\\1_song_(Vocals).wav'

            ############################ Prediction ############################
            pred_face = np.load(os.path.join(face_path, face_list[idx]))

            pred_face = torch.tensor(pred_face).squeeze().to('cuda')
            pred_jaw = pred_face[:, :3]
            pred_face = pred_face[:, 3:]

            for i in range(num_sample):
                pred_res = np.load(os.path.join(gesture_path,gesture_list[idx]))
                pred = torch.tensor(pred_res).squeeze().to('cuda')

                if pred.shape[0] < pred_face.shape[0]:
                    repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
                    pred = torch.cat([pred, repeat_frame], dim=0)
                else:
                    pred = pred[:pred_face.shape[0], :]

                # pred = torch.cat([pred, pred_face], dim=-1)
                pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
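                # assembled layout: jaw (3) | upper-body gesture | expression (100);
                # part2full below inserts the fixed eye / global-orient / lower-body
                # channels to reach the full per-frame vector sliced in get_vertices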

                pred = part2full(pred, stand)

                result_list.append(pred)

            vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

            result_list = [res.to('cpu') for res in result_list]
            pred_arr = np.concatenate(result_list[1:], axis=0)
            file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
            np.save(file_name, pred_arr)

            rendertool._render_sequences(cur_wav_file, vertices_list[1:], stand=stand, face=face)

def main():
    parser = parse_args()
    args = parser.parse_args()
    device = torch.device(args.gpu)
    torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)
    print('init dataloader...')
    infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to('cuda')

    if args.rename is not None:
        config.Log.name = args.rename
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(config.Data.data_root, generator, generator_face, generator2, args.exp_name, infer_loader, infer_set, device,
          norm_stats, True, smplx_model, rendertool, args, config, face_path=args.face_path, gesture_path=args.gesture_path)

if __name__ == '__main__':
    main()
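As a quick sketch of the inputs the script above expects (the shapes in the comments are indicative, based on how infer() reads the files, so treat them as assumptions): face_path and gesture_path should be directories of per-sequence .npy files whose sorted filenames pair up one-to-one; the face file carries jaw plus expression, the gesture file carries the upper-body pose, and the two are concatenated before part2full.

import os
import numpy as np

face_path = '/path/to/face'        # placeholder paths; pass your own via the arguments
gesture_path = '/path/to/gesture'

face_files = sorted(os.listdir(face_path))
gesture_files = sorted(os.listdir(gesture_path))

for f_name, g_name in zip(face_files, gesture_files):
    face = np.load(os.path.join(face_path, f_name))        # roughly [T, 103] after squeeze: jaw (3) + expression (100)
    gesture = np.load(os.path.join(gesture_path, g_name))  # roughly [T, 129] after squeeze: upper-body and hand pose
    print(f_name, face.shape, '<->', g_name, gesture.shape)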
TashvikDhamija commented 1 month ago

Hello, I get similar results to @jameskuma

I tried to check whether there is a mismatch between the DiffSHEG output parameters and the SHOW SMPL-X model inputs, but everything seems okay. Has anyone found the right way to render the SHOW results?

@JeremyCJM I tried running your code, but I cannot figure out what face_path and gesture_path are, since the DiffSHEG model only gives one .npy output. Also, I am not quite sure why it creates a dataset and loader for the whole TalkSHOW dataset while inferring a single output. Can you help me use your code for a single inference from the .npy output DiffSHEG gives?

Any help in visualising would be appreciated!

https://github.com/user-attachments/assets/08c0cd0e-f1ed-43eb-94a5-1ed136528e32

Here is my code:

import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

global device
device = 'cpu'

def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    elif model_name == 's2g_LS3DCG':
        generator = LS3DCG(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator

def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0),
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None

global_orient = torch.tensor([3.0747, -0.0158, -0.0152])

def infer(g_body, g_face, smplx_model, rendertool, config, args):
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = args.num_sample
    cur_wav_file = args.audio_file
    id = args.id
    face = args.only_face
    stand = args.stand
    if face:
        body_static = torch.zeros([1, 162], device=device)
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)

    # result_list = []

    # pred_face = g_face.infer_on_audio(cur_wav_file,
    #                                   initial_pose=None,
    #                                   norm_stats=None,
    #                                   w_pre=False,
    #                                   # id=id,
    #                                   frame=None,
    #                                   am=am,
    #                                   am_sr=am_sr
    #                                   )
    # pred_face = torch.tensor(pred_face).squeeze().to(device)
    # # pred_face = torch.zeros([gt.shape[0], 105])

    # if config.Data.pose.convert_to_6d:
    #     pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
    #     pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
    #     pred_face = pred_face[:, 6:]
    # else:
    #     pred_jaw = pred_face[:, :3]
    #     pred_face = pred_face[:, 3:]

    # id = torch.tensor([id], device=device)

    # for i in range(num_sample):
    #     pred_res = g_body.infer_on_audio(cur_wav_file,
    #                                      initial_pose=None,
    #                                      norm_stats=None,
    #                                      txgfile=None,
    #                                      id=id,
    #                                      var=None,
    #                                      fps=30,
    #                                      w_pre=False
    #                                      )
    #     pred = torch.tensor(pred_res).squeeze().to(device)

    #     if pred.shape[0] < pred_face.shape[0]:
    #         repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
    #         pred = torch.cat([pred, repeat_frame], dim=0)
    #     else:
    #         pred = pred[:pred_face.shape[0], :]

    #     body_or_face = False
    #     if pred.shape[1] < 275:
    #         body_or_face = True
    #     if config.Data.pose.convert_to_6d:
    #         pred = pred.reshape(pred.shape[0], -1, 6)
    #         pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
    #         pred = pred.reshape(pred.shape[0], -1)

    #     if config.Model.model_name == 's2g_LS3DCG':
    #         pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
    #     else:
    #         pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

    #     # pred[:, 9:12] = global_orient
    #     pred = part2full(pred, stand)
    #     if face:
    #         pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
    #     # result_list[0] = poses2pred(result_list[0], stand)
    #     # if gt_0 is None:
    #     #     gt_0 = gt
    #     # pred = pred2poses(pred, gt_0)
    #     # result_list[0] = poses2poses(result_list[0], gt_0)

    #     result_list.append(pred)

    result_list = torch.from_numpy(np.load('../DiffSHEG/results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_4/gesture/Forrest_tts.npy'))
    result_list = part2full(result_list[0], stand=True).unsqueeze(0)
    print(result_list.shape)
    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

    result_list = [res.to('cpu') for res in result_list]
    pred_arr = np.concatenate(result_list[:], axis=0)
    file_name = 'visualise/video/' + config.Log.name + '/' + \
                cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
    np.save(file_name, pred_arr)
    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)

def main():
    parser = parse_args()
    args = parser.parse_args()
    # device = torch.device(args.gpu)
    # torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to(device)
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(generator, generator_face, smplx_model, rendertool, config, args)

if __name__ == '__main__':
    main()