fabro66 / GAST-Net-3DPoseEstimation

A Graph Attention Spatio-temporal Convolutional Network for 3D Human Pose Estimation in Video (GAST-Net)
MIT License
313 stars 70 forks source link

How to obtain a JSON file similar to baseball.json #21

Closed guerrifrancesco closed 3 years ago

guerrifrancesco commented 3 years ago

Hi, I'd like to generate a JSON file similar to baseball.json from my own video. How did you obtain yours?

fabro66 commented 3 years ago

Hi! Add the following functions to the "lib/pose/hrnet/pose_estimation/gen_kpts.py" file to generate a JSON file similar to baseball.json.

def round_list(input_list, decimals=3):
    """Round every number in a list of rows, in place, and return that same list.

    Each element ``row[j]`` is overwritten with ``round(row[j], decimals)``.
    Neither the outer list nor the row objects are copied, so the caller's
    data is modified as well.

    Args:
        input_list: a sequence whose items are mutable, indexable rows of numbers.
        decimals: number of decimal places passed to ``round``.

    Returns:
        ``input_list`` itself, after rounding.
    """
    for row in input_list:
        for j, value in enumerate(row):
            row[j] = round(value, decimals)
    return input_list

def _ntu_track_people(frame, human_model, people_sort, thred_score, num_person):
    """Detect and track people in one frame; return their rounded boxes.

    Args:
        frame: the image returned by ``cap.read()``.
        human_model: the detector returned by ``yolo_model()``.
        people_sort: the ``Sort`` tracker, shared across all frames of the video.
        thred_score: confidence threshold forwarded to ``yolo_det``.
        num_person: maximum number of tracked people to keep (must be >= 1).

    Returns:
        A list of at most ``num_person`` boxes. Each box is the tracker row
        without its trailing track-ID column, rounded to 3 decimals. An empty
        list means nobody was detected or tracked, or detection/tracking raised.
    """
    try:
        bboxs, _scores = yolo_det(frame, human_model, confidence=thred_score)
        if bboxs is None or not bboxs.any():
            print('No person detected!')
            return []

        # Using Sort to track people; the last column of each row is the track ID.
        people_track = people_sort.update(bboxs)

        # Keep the last `num_person` tracks, drop the ID column and reverse the row order.
        kept = people_track[-num_person:, :-1][::-1]
        # float() converts numpy scalars to Python floats so json.dump can always serialize them.
        return [[round(float(v), 3) for v in bbox] for bbox in kept]
    except Exception as e:  # best-effort per frame: report the error and treat the frame as empty
        print(e)
        return []


def generate_ntu_kpts_json(video_path, kpts_file):
    """Estimate 2D keypoints for every frame of a video and write a baseball.json-like file.

    The written JSON has the layout::

        {"data": [{"frame_index": 1,
                   "skeleton": [{"pose": [...], "score": [...], "bbox": [...]}, ...]},
                  ...]}

    Every frame that is successfully read produces exactly one entry in ``data``.
    When nobody is detected or tracked in a frame (or detection fails), the entry
    holds a single empty skeleton ``{'pose': [], 'score': [], 'bbox': []}``. This
    keeps ``frame_index`` values contiguous.

    At most ``args.num_person`` tracked people are kept per frame. Pose values
    and scores are rounded to 3 decimals with ``round_list``.

    Args:
        video_path: path of the input video, opened with ``cv2.VideoCapture``.
        kpts_file: path of the JSON file to write (overwritten if it exists).
    """
    args = parse_args()
    reset_config(args)
    # `people_track[-0:]` would select *every* track, so never let this drop below 1.
    num_person = max(1, int(args.num_person))

    # Loading detector and pose model, initialize sort for track
    human_model = yolo_model()
    pose_model = model_load(cfg)
    people_sort = Sort()
    use_cuda = torch.cuda.is_available()

    data = []
    cap = cv2.VideoCapture(video_path)
    try:
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with torch.no_grad():
            for i in tqdm(range(video_length)):
                frame_index = i + 1
                ret, frame = cap.read()
                if not ret:
                    # The reported frame count can exceed the frames actually decodable.
                    break

                track_bboxs = _ntu_track_people(frame, human_model, people_sort,
                                                args.thred_score, num_person)
                if not track_bboxs:
                    data.append({'frame_index': frame_index,
                                 'skeleton': [{'pose': [], 'score': [], 'bbox': []}]})
                    continue

                # Use the *tracked* boxes so that preds[num] corresponds to track_bboxs[num].
                inputs, _origin_img, center, scale = PreProcess(frame, track_bboxs, cfg, num_person)
                # Reorder axis 1 as (2, 1, 0) -- presumably BGR (OpenCV) -> RGB; TODO confirm.
                inputs = inputs[:, [2, 1, 0]]
                if use_cuda:
                    inputs = inputs.cuda()
                output = pose_model(inputs)
                # compute coordinate
                preds, maxvals = get_final_preds(cfg, output.clone().cpu().numpy(),
                                                 np.asarray(center), np.asarray(scale))

                skeleton = [{'pose': round_list(preds[num].tolist()),
                             'score': round_list(maxvals[num].tolist()),
                             'bbox': bbox}
                            for num, bbox in enumerate(track_bboxs)]
                data.append({'frame_index': frame_index, 'skeleton': skeleton})
    finally:
        cap.release()

    with open(kpts_file, 'w') as fw:
        json.dump({'data': data}, fw)
    print('Finishing!')
guerrifrancesco commented 3 years ago

Thank you!!

1) I put the functions in the file you mentioned and called them from a Python script of my own. I got this error:

```
  File ".../lib/pose/hrnet/pose_estimation/gen_kpts.py", line 99, in generate_ntu_kpts_json
    inputs, origin_img, center, scale = PreProcess(frame, bboxs, cfg, args.num_pos)
AttributeError: 'Namespace' object has no attribute 'num_pos'
```

What does that attribute represent?

2) I noticed that the function detects the people in each frame but keeps only the first two of them. Is it possible to remove this limit and obtain 2D skeletons for three or more people?

fabro66 commented 3 years ago

1). Replace "args.num_pos" with "args.num_person"

2). Yes. You can change the code to keep more than two tracked people per frame and obtain multi-person keypoints.