ViTAE-Transformer / ViTPose

The official repo for [NeurIPS'22] "ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation" and [TPAMI'23] "ViTPose++: Vision Transformer for Generic Body Pose Estimation"
Apache License 2.0
1.38k stars 187 forks source link

KeyError: 'dataset_idx' #69

Open MaxTeselkin opened 1 year ago

MaxTeselkin commented 1 year ago

Hi! I am trying to run inference with the ViTPose+ base model using the following code:

# importing necessary libraries
import warnings
warnings.filterwarnings('ignore')
import torch
import torchvision
import cv2
from mmpose.apis import (inference_top_down_pose_model,
                         init_pose_model,
                         vis_pose_result,
                         process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
from IPython.display import Image, display
import tempfile
import os

# define model configs and checkpoints
pose_config = '/kaggle/working/ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vitPose+_base_coco+aic+mpii+ap10k+apt36k+wholebody_256x192_udp.py'
pose_checkpoint = '/kaggle/input/vitpose-plus-basic/vitpose-plus-b.pth'
det_config = '/kaggle/working/ViTPose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
det_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# initialize pose model
pose_model = init_pose_model(pose_config, pose_checkpoint, device='cpu')
# initialize detector
det_model = init_detector(det_config, det_checkpoint, device='cpu')

img = '/kaggle/working/ViTPose/tests/data/coco/000000196141.jpg'

# inference detection
mmdet_results = inference_detector(det_model, img)

# extract person (COCO_ID=1) bounding boxes from the detection results
person_results = process_mmdet_results(mmdet_results, cat_id=1)

# inference pose
pose_results, returned_outputs = inference_top_down_pose_model(pose_model,
                                                               img,
                                                               person_results,
                                                               bbox_thr=0.3,
                                                               format='xyxy',
                                                               dataset=pose_model.cfg.data.test.type)

# show pose estimation results
vis_result = vis_pose_result(pose_model,
                             img,
                             pose_results,
                             dataset=pose_model.cfg.data.test.type,
                             show=False)

with tempfile.TemporaryDirectory() as tmpdir:
    file_name = os.path.join(tmpdir, 'pose_results.png')
    cv2.imwrite(file_name, vis_result)
    display(Image(file_name))

But I get the following error:

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_27/4115304742.py in <module>
     24                                                                bbox_thr=0.3,
     25                                                                format='xyxy',
---> 26                                                                dataset=pose_model.cfg.data.test.type)
     27 
     28 # show pose estimation results

/kaggle/working/ViTPose/mmpose/apis/inference.py in inference_top_down_pose_model(model, img_or_path, person_results, bbox_thr, format, dataset, dataset_info, return_heatmap, outputs)
    404             dataset=dataset,
    405             dataset_info=dataset_info,
--> 406             return_heatmap=return_heatmap)
    407 
    408         if return_heatmap:

/kaggle/working/ViTPose/mmpose/apis/inference.py in _inference_single_pose_model(model, img_or_path, bboxes, dataset, dataset_info, return_heatmap)
    276             data['image_file'] = img_or_path
    277 
--> 278         data = test_pipeline(data)
    279         batch_data.append(data)
    280 

/kaggle/working/ViTPose/mmpose/datasets/pipelines/shared_transform.py in __call__(self, data)
     97         """
     98         for t in self.transforms:
---> 99             data = t(data)
    100             if data is None:
    101                 return None

/kaggle/working/ViTPose/mmpose/datasets/pipelines/shared_transform.py in __call__(self, results)
    166                 else:
    167                     key_src = key_tgt = key
--> 168                 meta[key_tgt] = results[key_src]
    169         if 'bbox_id' in results:
    170             meta['bbox_id'] = results['bbox_id']

KeyError: 'dataset_idx'

Serdnad commented 1 year ago

Ran into the same issue. I'm guessing it has to do with dataset_idx being populated in the configs. I don't know what the proper fix is, but I found a temporary workaround: replacing

img_sources = torch.from_numpy(np.array([ele['dataset_idx'] for ele in img_metas])).to(img.device)

with

img_sources = torch.from_numpy(np.array([0 for ele in img_metas])).to(img.device)

in the 2 places it's used by mmpose. That said, I modified the config to only train on COCO, and I'm not sure my workaround would work otherwise.

XiongFenghhh commented 10 months ago

> Ran into the same issue, I'm guessing it has to do with dataset_idx being populated in the configs. I don't know what the proper fix is, but I found a temporary workaround replacing
>
> img_sources = torch.from_numpy(np.array([ele['dataset_idx'] for ele in img_metas])).to(img.device)
>
> with
>
> img_sources = torch.from_numpy(np.array([0 for ele in img_metas])).to(img.device)
>
> in the 2 places it's used by mmpose. That said, I modified the config to only train on COCO, and I'm not sure my workaround would work otherwise.

Hello, which file did you modify? I ran into the same problem.