Wenchao-M / SAGE

Official implementation of the paper "Stratified Avatar Generation from Sparse Observations"

real-time inference #3

Open dnpcs12 opened 1 month ago

dnpcs12 commented 1 month ago

Hello, thank you for sharing your excellent project! I wanted to use the model in real time, so I wrote the code below to run real-time inference with the checkpoints you provide. However, each frame takes around 0.08 to 0.1 s, which is much slower than the 0.74 ms reported in the paper. Did I do something wrong? I set the window size to 20 and tested on an NVIDIA RTX 3090.

import os
import random
import socket
import time

import numpy as np
import torch

# Project-specific modules are assumed to come from the SAGE repository:
# get_args, merge_file, BodyModel, TransformerVQVAE, MotionDiffusion,
# TransformerDecoder, Refinenet, upper_body, lower_body, process_quan,
# utils_transform
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed; the repo may define this elsewhere

class InferenceModel:
    def __init__(self,args = None):
        if args is None:
            cfg_args = get_args()
            cfg_args.cfg = 'config_decoder/refiner_FULL.yaml'
            args = merge_file(cfg_args)
            name = cfg_args.cfg.split('/')[-1].split('.')[0]  # config name, used as the output directory name
            args.SAVE_DIR = os.path.join("outputs", name)
        self.args = args
        torch.backends.cudnn.benchmark = False
        random.seed(args.SEED)
        np.random.seed(args.SEED)
        torch.manual_seed(args.SEED)
        fps = args.FPS  # AMASS dataset requires 60 frames per second
        body_model = BodyModel(args.SUPPORT_DIR).to(device)

        # Load VQVAE model
        vqcfg = args.VQVAE
        self.vq_model_upper = TransformerVQVAE(in_dim=len(upper_body) * 6, n_layers=vqcfg.n_layers, hid_dim=vqcfg.hid_dim,
                                          heads=vqcfg.heads, dropout=vqcfg.dropout, n_codebook=vqcfg.n_codebook,
                                          n_e=vqcfg.n_e, e_dim=vqcfg.e_dim, beta=vqcfg.beta)
        self.vq_model_lower = TransformerVQVAE(in_dim=len(lower_body) * 6, n_layers=vqcfg.n_layers, hid_dim=vqcfg.hid_dim,
                                          heads=vqcfg.heads, dropout=vqcfg.dropout, n_codebook=vqcfg.n_codebook,
                                          n_e=vqcfg.n_e, e_dim=vqcfg.e_dim, beta=vqcfg.beta)

        # Load Diffusion model
        self.diff_model_upper = MotionDiffusion(cfg=args.DIFFUSION, input_length=args.INPUT_MOTION_LENGTH,
                                           num_layers=args.DIFFUSION.layers_upper, use_upper=False).to(device)
        self.diff_model_lower = MotionDiffusion(cfg=args.DIFFUSION, input_length=args.INPUT_MOTION_LENGTH,
                                           num_layers=args.DIFFUSION.layers_lower, use_upper=True).to(device)
        self.decoder_model = TransformerDecoder(in_dim=132, seq_len=args.INPUT_MOTION_LENGTH, **args.DECODER).to(device)
        self.refineNet = Refinenet(n_layers=args.REFINER.n_layers, hidder_dim=args.REFINER.hidden_dim).to(device)

        # Upper VQVAE weight
        upper_vq_dir = args.UPPER_VQ_DIR
        vqvae_upper_file = os.path.join(upper_vq_dir, 'best.pth.tar')
        if os.path.exists(vqvae_upper_file):
            checkpoint_upper = torch.load(vqvae_upper_file, map_location=lambda storage, loc: storage)
            self.vq_model_upper.load_state_dict(checkpoint_upper)
            print(f"=> Load upper vqvae {vqvae_upper_file}")
        else:
            print("No upper vqvae model!")
            return

        # Lower VQVAE weight
        lower_vq_dir = args.LOWER_VQ_DIR
        vqvae_lower_file = os.path.join(lower_vq_dir, 'best.pth.tar')
        if os.path.exists(vqvae_lower_file):
            checkpoint_lower = torch.load(vqvae_lower_file, map_location=lambda storage, loc: storage)
            self.vq_model_lower.load_state_dict(checkpoint_lower)
            print(f"=> Load upper vqvae {vqvae_lower_file}")
        else:
            print("No lower vqvae model!")
            return

        decoder_dir = args.DECODER_DIR
        decoder_file = os.path.join(decoder_dir, 'best.pth.tar')
        if os.path.exists(decoder_file):
            checkpoint_all = torch.load(decoder_file, map_location=lambda storage, loc: storage)
            self.diff_model_upper.load_state_dict(checkpoint_all['upper_state_dict'])
            self.diff_model_lower.load_state_dict(checkpoint_all['lower_state_dict'])
            self.decoder_model.load_state_dict(checkpoint_all['decoder_state_dict'])
            print("=> loading checkpoint '{}'".format(decoder_file))
        else:
            print("decoder file not exist!")
            return

        if hasattr(args, 'REFINER_DIR'):
            output_dir = args.REFINER_DIR
        else:
            output_dir = args.SAVE_DIR

        refine_file = os.path.join(output_dir, 'best.pth.tar')
        if os.path.exists(refine_file):
            print("=> loading refine model '{}'".format(refine_file))
            refine_checkpoint = torch.load(refine_file, map_location=lambda storage, loc: storage)
            self.refineNet.load_state_dict(refine_checkpoint)
        else:
            print(f"{refine_file} not exist!!!")
            return

        # switch all sub-models to eval mode for inference
        self.vq_model_upper.eval()
        self.vq_model_lower.eval()
        self.diff_model_upper.eval()
        self.diff_model_lower.eval()
        self.decoder_model.eval()
        self.refineNet.eval()

    def inference(self, sparse):
        '''
        :param sparse: (window_size, 54) sparse observation, e.g. (20, 54)
        :return: refined 6D rotations (22 joints) for the last frame of the window
        '''
        # reshape to (1, window, 3, 18), move to the GPU and cast to float
        sparse = sparse.reshape(1, sparse.shape[0], 3, 18).to(device).float()
        # two-stage diffusion (upper body, then lower body conditioned on it),
        # followed by decoding and refinement
        with torch.no_grad():
            upper_latents = self.diff_model_upper.diffusion_reverse(sparse)
            lower_latents = self.diff_model_lower.diffusion_reverse(sparse, upper_latents)
            recover_6d = self.decoder_model(upper_latents, lower_latents, sparse)
            # keep only the last frame of the window as the per-frame output
            initial_res = recover_6d[:, -1].reshape(-1, 22 * 6)
            pred_delta, hid = self.refineNet(initial_res[None], None)
            final_pred = pred_delta.squeeze() + initial_res
        return final_pred.cpu()

def udp_server():
    model = InferenceModel()
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    portNum = 7868
    ip = "0.0.0.0"
    sock.bind((ip, portNum))

    while True:
        data, addr = sock.recvfrom(10240)
        received_data = data.decode('utf-8')

        input_data = process_quan.process_received_data(received_data)
        start_time = time.time()
        output = model.inference(input_data)
        output = utils_transform.sixd2quat(output.reshape(-1, 6)).reshape(1, -1)
        # wall-clock time per frame (inference plus 6D-to-quaternion conversion)
        print(str(time.time() - start_time) + " seconds")
        # serialize the predicted rotations (p) and a zero translation (t) as
        # comma-separated values, joined by '#' and terminated by '$'
        p = output[0]
        t = [0, 0, 0]
        s = ','.join(['%g' % v for v in p]) + '#' + \
            ','.join(['%g' % v for v in t]) + '$'

        sock.sendto(s.encode("utf8"), addr)

if __name__ == "__main__":
    udp_server()
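
As a side note on measurement: the first few calls include one-time costs (CUDA context creation, cuDNN autotuning), and the span timed in udp_server also covers the 6D-to-quaternion conversion. Below is a minimal sketch of an isolated, averaged timing of the model call, assuming a CUDA device and the InferenceModel posted above (time_inference, n_warmup, and n_iters are hypothetical names):

import time
import torch

def time_inference(model, sparse_window, n_warmup=10, n_iters=100):
    # warm-up runs absorb one-time costs before timing starts
    for _ in range(n_warmup):
        model.inference(sparse_window)
    # synchronize so the timer measures completed GPU work
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        model.inference(sparse_window)
    torch.cuda.synchronize()
    return (time.time() - start) / n_iters  # average seconds per window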
Wenchao-M commented 1 month ago

Hi, thank you for your interest in our work. I reviewed our evaluation code for the speed measurement and compared it with the inference code you provided. The main difference is that we concatenate multiple sequences into a batch while keeping the sliding-window approach: with a sliding window of size 20 producing one output frame per window, many windows are stacked into a batch and fed through the network in a single pass. This allows us to fully utilize the GPU for inference, albeit with slightly higher latency.
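
A minimal sketch of this batched sliding-window evaluation, reusing the InferenceModel and device from the snippet above (make_sliding_windows, batched_inference, win_size, and sparse_seq are hypothetical names; the actual evaluation script may organize this differently):

import torch

def make_sliding_windows(sparse_seq, win_size=20):
    # sparse_seq: (T, 54) sequence of sparse observations
    # returns: (T - win_size + 1, win_size, 54), one window per output frame
    windows = [sparse_seq[i:i + win_size] for i in range(sparse_seq.shape[0] - win_size + 1)]
    return torch.stack(windows, dim=0)

def batched_inference(model, sparse_seq, win_size=20):
    windows = make_sliding_windows(sparse_seq, win_size)
    sparse = windows.reshape(windows.shape[0], win_size, 3, 18).to(device).float()
    with torch.no_grad():
        # one forward pass over all windows amortizes the per-frame cost
        upper_latents = model.diff_model_upper.diffusion_reverse(sparse)
        lower_latents = model.diff_model_lower.diffusion_reverse(sparse, upper_latents)
        recover_6d = model.decoder_model(upper_latents, lower_latents, sparse)
        # the last frame of each window is that frame's prediction
        preds = recover_6d[:, -1].reshape(-1, 22 * 6)
        pred_delta, _ = model.refineNet(preds[None], None)
        preds = pred_delta.squeeze(0) + preds
    return preds.cpu()  # (T - win_size + 1, 132)

Each frame's pose only becomes available once its whole batch has finished, which is the added latency mentioned above; the per-frame cost, however, is amortized over the batch.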