BNU-IVC / FastPoseGait

FastPoseGait is a user-friendly and flexible repository that aims to help researchers get started on pose-based gait recognition quickly.

Usage for inference? #7

Closed by louis030195 10 months ago

louis030195 commented 1 year ago

Hey, thanks for the great work!

Would it be possible to explain how to use the pre-trained models for inference, for example on an .mp4 file? I've been fighting with the transformation and still can't get it right:

import cv2
from modeling import models
from utils import config_loader, get_msg_mgr
from data.transform import GaitTR_MultiInput, SkeletonInput
import argparse
import os
import torch.distributed as dist
import torch
import numpy as np

# ... some code ...

def infer_video(video_path):
    video = cv2.VideoCapture(video_path)

    results = []
    while video.isOpened():
        ret, frame = video.read()
        print('frame shape: ', np.shape(frame))
        if not ret:
            break

        # Prepare the input for the model
        labs = torch.zeros(1)  # Dummy labels
        seqL = torch.zeros(1)  # Dummy sequence length
        gait_transform = GaitTR_MultiInput(joint_format='coco')
        skeleton_transform = SkeletonInput()

        # Apply the transformations to your frame
        frame = gait_transform(frame)
        frame = skeleton_transform(frame)
        ipts = (torch.from_numpy(frame),)  # Convert frame to tensor
        ipts = (ipts[0].unsqueeze(0),)  # Add an extra dimension at position 0

        # Run inference on the frame
        result = model((ipts, labs, None, None, seqL))
        results.append(result)

    video.release()

    # Process and return the results
    return results

# ... some code ...

My transformation is still wrong here. I would love your feedback on how to properly run inference on a video!

It would be super helpful. Thanks 🙏

DreamShibei commented 1 year ago

Hi! Is your video in RGB format? Our transformation only supports input in the format [T, V, C], where T is the number of frames, V is the number of keypoints, and C is the number of coordinate channels per keypoint. Raw RGB frames read from a video cannot be passed to it directly.
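
To illustrate the [T, V, C] layout described above, here is a minimal sketch that stacks per-frame keypoint arrays into a single sequence. It is not part of FastPoseGait: `stack_keypoints` is a hypothetical helper, the random data stands in for the output of whatever pose estimator you run on the video frames, and the choice of 17 COCO keypoints with 3 channels (x, y, confidence) is an assumption you should check against the config you are using.

```python
import numpy as np


def stack_keypoints(per_frame_keypoints):
    """Stack per-frame (V, C) keypoint arrays into one (T, V, C) array.

    Hypothetical helper for illustration only; not a FastPoseGait API.
    """
    frames = [np.asarray(k, dtype=np.float32) for k in per_frame_keypoints]
    if not frames:
        raise ValueError("no frames with keypoints were provided")

    expected = frames[0].shape
    if len(expected) != 2:
        raise ValueError(f"each frame must be 2-D (V, C), got shape {expected}")

    # Every frame must have the same number of keypoints and channels.
    for index, frame in enumerate(frames):
        if frame.shape != expected:
            raise ValueError(
                f"frame {index} has shape {frame.shape}, expected {expected}"
            )

    # New leading axis is time: result is (T, V, C).
    return np.stack(frames, axis=0)


# Placeholder data standing in for a pose estimator's per-frame output.
# Assumed layout: 30 frames, 17 COCO keypoints, 3 channels (x, y, confidence).
rng = np.random.default_rng(0)
fake_keypoints = [rng.random((17, 3)) for _ in range(30)]

sequence = stack_keypoints(fake_keypoints)
print(sequence.shape)  # (30, 17, 3)
```

In other words, the video would first go through a pose estimator frame by frame, and only the resulting keypoint sequence (the whole clip at once, not one frame at a time) would be handed to the transforms.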