open-mmlab / mmpose

OpenMMLab Pose Estimation Toolbox and Benchmark.
https://mmpose.readthedocs.io/en/latest/
Apache License 2.0

Use_udp error: cannot convert float NaN to integer #1273

Open · spoonbobo opened this issue 2 years ago

spoonbobo commented 2 years ago

Hi folks, I used MMPose to train on the CarFusion dataset using HRNet with UDP (hrnet_w32_coco_384x288_udp). The error ValueError: cannot convert float NaN to integer is raised as soon as I add the following step to the training and validation pipelines: dict(type='TopDownAffine', use_udp=True). Without use_udp=True, training works fine.
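
For reference, the division that the traceback below points at can be reproduced in isolation; with the sizes from my config it is well defined, so the NaN presumably enters somewhere upstream (a sketch using the values from data_cfg below):

import numpy as np

# the computation from _udp_generate_target, with config values plugged in
image_size = np.array([288, 384], dtype=np.float64)
heatmap_size = np.array([72, 96], dtype=np.float64)
feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
print(feat_stride)  # ~[4.04, 4.03]; a NaN in center, scale or joints upstream
                    # is what would eventually hit the int() casts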

Training config:

joint_weights = [1, 1, 1, 1, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
sigmas = [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.01, 0.062, 0.062, 0.107, 0.107, 0.01]

keypoint_info = dict({
            # bottom keypoints
            0: dict(name='bottom_front_right', id=0, color=[30, 252, 100], type='lower', swap='bottom_front_left'),
            1: dict(name='bottom_front_left', id=1, color=[30, 252, 100], type='lower', swap='bottom_front_right'),
            2: dict(name='bottom_back_right', id=2, color=[30, 252, 100], type='lower', swap='bottom_back_left'),
            3: dict(name='bottom_back_left', id=3, color=[30, 252, 100], type='lower', swap='bottom_back_right'),

            # middle keypoints
            4: dict(name='middle_front_right', id=4, color=[255, 246, 1], type='upper', swap='middle_front_left'),
            5: dict(name='middle_front_left', id=5, color=[255, 246, 1], type='upper', swap='middle_front_right'),
            6: dict(name='middle_back_right', id=6, color=[250, 84, 124], type='upper', swap='middle_back_left'),
            7: dict(name='middle_back_left', id=7, color=[250, 84, 124], type='upper', swap='middle_back_right'),
            8: dict(name='ghost_8', id=8, color=[0, 0, 0], type='', swap=''),

            # top keypoints
            9: dict(name='top_front_right', id=9, color=[255, 246, 1], type='upper', swap='top_front_left'),
            10: dict(name='top_front_left', id=10, color=[255, 246, 1], type='upper', swap='top_front_right'),
            11: dict(name='top_back_right', id=11, color=[250, 84, 124], type='upper', swap='top_back_left'),
            12: dict(name='top_back_left', id=12, color=[250, 84, 124], type='upper', swap='top_back_right'),
            13: dict(name='ghost_13', id=13, color=[0, 0, 0], type='', swap='')
})

skeleton_info = dict({
            # bottom edges
            0: dict(link=('bottom_front_right', 'bottom_front_left'), id=0, color=[30, 252, 100]),
            1: dict(link=('bottom_back_right', 'bottom_back_left'), id=1, color=[30, 252, 100]),
            2: dict(link=('bottom_front_right', 'bottom_back_right'), id=2, color=[30, 252, 100]),
            3: dict(link=('bottom_front_left', 'bottom_back_left'), id=3, color=[30, 252, 100]),
            # front edges
            4: dict(link=('middle_front_left', 'middle_front_right'), id=4, color=[255, 246, 1]),
            5: dict(link=('bottom_front_left', 'middle_front_left'), id=5, color=[255, 246, 1]),
            6: dict(link=('bottom_front_right', 'middle_front_right'), id=6, color=[255, 246, 1]),
            7: dict(link=('middle_front_left', 'top_front_left'), id=7, color=[255, 246, 1]),
            8: dict(link=('middle_front_right', 'top_front_right'), id=8, color=[255, 246, 1]),
            # back edges
            9: dict(link=('middle_back_left', 'middle_back_right'), id=9, color=[250, 84, 124]),
            10: dict(link=('bottom_back_left', 'middle_back_left'), id=10, color=[250, 84, 124]),
            11: dict(link=('bottom_back_right', 'middle_back_right'), id=11, color=[250, 84, 124]),
            12: dict(link=('middle_back_left', 'top_back_left'), id=12, color=[250, 84, 124]),
            13: dict(link=('middle_back_right', 'top_back_right'), id=13, color=[250, 84, 124])

        })

dataset_info = dict(
    dataset_name='carfusion',
    paper_info=paper_info,
    keypoint_info=keypoint_info,
    skeleton_info=skeleton_info,
    joint_weights=joint_weights,
    sigmas=sigmas)

log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
checkpoint_config = dict(interval=10)
evaluation = dict(interval=10, metric='mAP', save_best='AP')
optimizer = dict(type='Adam', lr=0.0005)
optimizer_config = dict(grad_clip=None)
target_type = 'GaussianHeatmap'
work_dir = 'work_dirs/hrnet_w32_coco_tiny_256x192'
data_root = 'data/coco_tiny'
gpu_ids = range(0, 1)
seed = 0

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[170, 200])
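
# Note: with total_epochs = 60 below, these step milestones (epochs 170 and
# 200) are never reached, so the learning rate stays at 5e-4 after warmup.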

total_epochs = 60

log_config = dict(interval=1, hooks=[dict(type='TextLoggerHook')])

channel_cfg = dict(
    num_output_channels=14,
    dataset_joints=14,
    dataset_channel=[[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13
    ]],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13
    ])

model = dict(
    type='TopDown',
    pretrained=
    'https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w32-36af842e.pth',
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256)))),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=32,
        out_channels=14,
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1),
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='default',
        target_type=target_type,
        shift_heatmap=False,
        modulate_kernel=17,
        use_udp=True))

data_cfg = dict(
    image_size=[288, 384],  # changed from [192, 256]
    heatmap_size=[72, 96],
    num_output_channels=14,
    num_joints=14,
    dataset_channel=[[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13
    ]],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13
    ],
    soft_nms=False,
    nms_thr=1.0,
    oks_thr=0.9,
    vis_thr=0.2,
    use_gt_bbox=False,
    det_bbox_thr=0.0,
    bbox_file=
    'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
)
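
# Note: UDP computes feat_stride = (image_size - 1) / (heatmap_size - 1)
# (see the traceback below), so both sizes above must reach ann_info intact;
# a missing or NaN value there propagates straight into target generation.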

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownRandomFlip', flip_prob=0.5),
    dict(
        type='TopDownHalfBodyTransform',
        num_joints_half_body=8,
        prob_half_body=0.3),
    dict(
        type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
    dict(type='TopDownAffine', use_udp=True),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='TopDownGenerateTarget',
        sigma=3,
        encoding='UDP',
        target_type=target_type),
    dict(
        type='Collect',
        keys=['img', 'target', 'target_weight'],
        meta_keys=[
            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
            'rotation', 'bbox_score', 'flip_pairs'
        ])
]
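
# TopDownGenerateTarget with encoding='UDP' pairs with
# TopDownAffine(use_udp=True) above and with use_udp=True in test_cfg, so
# the unbiased data processing is applied consistently end to end.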

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownAffine', use_udp=True),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'image_file', 'center', 'scale', 'rotation', 'bbox_score',
            'flip_pairs'
        ])
]

test_pipeline = val_pipeline

data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    val_dataloader=dict(samples_per_gpu=32),
    test_dataloader=dict(samples_per_gpu=32),
    train=dict(
        type='TopDownCOCOTinyDataset',
        ann_file='data/coco_tiny/train.json',
        img_prefix='data/coco_tiny/images/',
        data_cfg=data_cfg,
        pipeline=train_pipeline,

        dataset_info=dict(
            dataset_name='coco',
            paper_info=paper_info,
            keypoint_info=keypoint_info,
            skeleton_info=skeleton_info,
            joint_weights=joint_weights,
            sigmas=sigmas)),

    val=dict(
        type='TopDownCOCOTinyDataset',
        ann_file='data/coco_tiny/val.json',
        img_prefix='data/coco_tiny/images/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,

        dataset_info=dict(
            dataset_name='coco',
            paper_info=paper_info,
            keypoint_info=keypoint_info,
            skeleton_info=skeleton_info,
            joint_weights=joint_weights,
            sigmas=sigmas)),

    test=dict(
        type='TopDownCOCOTinyDataset',
        ann_file='data/coco_tiny/val.json',
        img_prefix='data/coco_tiny/images/',
        data_cfg=data_cfg,
        pipeline=test_pipeline,

        dataset_info=dict(
            dataset_name='coco',
            paper_info=paper_info,
            keypoint_info=keypoint_info,
            skeleton_info=skeleton_info,
            joint_weights=joint_weights,
            sigmas=sigmas)))
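
# Note: the data config above builds TopDownCOCOTinyDataset with
# dataset_name='coco', while the custom dataset below is registered as
# TopDownCarFusionDataset with dataset_name='carfusion'; presumably the
# carfusion variants are the ones actually intended here.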

Dataset config:

import json
import os
import os.path as osp
from collections import OrderedDict
import numpy as np
from mmpose.core.evaluation.top_down_eval import (keypoint_nme,
                                                  keypoint_pck_accuracy)
from mmpose.datasets.builder import DATASETS
from ..base import Kpt2dSviewRgbImgTopDownDataset

@DATASETS.register_module(name="TopDownCarFusionDataset")
class TopDownCarFusionDataset(Kpt2dSviewRgbImgTopDownDataset):

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 dataset_info=None,
                 test_mode=False):
        super().__init__(
            ann_file, img_prefix, data_cfg, pipeline, dataset_info, coco_style=False, test_mode=test_mode
        )

        # flip_pairs, upper_body_ids and lower_body_ids will be used
        # in some data augmentations like random flip and half-body transform
        self.ann_info['flip_pairs'] = [[0, 1], [2, 3], [4, 5], [6, 7], [9, 10], [11, 12]]
        # TopDownHalfBodyTransform reads the upper_body_ids/lower_body_ids
        # keys; per keypoint_info, ids 0-3 are 'lower' and ids 4-13 'upper'
        self.ann_info['upper_body_ids'] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13)
        self.ann_info['lower_body_ids'] = (0, 1, 2, 3)
        # parameter tuning
        self.ann_info['joint_weights'] = [1, 1, 1, 1, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
        self.ann_info['use_different_joint_weights'] = False

        self.dataset_name = 'carfusion'
        self.db = self._get_db()

    def _get_db(self):
        with open(self.ann_file) as f:
            anns = json.load(f)

        db = []
        for idx, ann in enumerate(anns):
            # get image path
            image_file = osp.join(self.img_prefix, ann['image_file'])
            # get bbox
            bbox = ann['bbox']
            center, scale = self._xywh2cs(*bbox)
            # get keypoints
            keypoints = np.array(
                ann['keypoints'], dtype=np.float32).reshape(-1, 3)
            num_joints = keypoints.shape[0]
            joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
            joints_3d[:, :2] = keypoints[:, :2]
            joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
            joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
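            # the third keypoint column is a visibility flag; np.minimum
            # clamps COCO-style {0, 1, 2} values down to {0, 1}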

            sample = {
                'image_file': image_file,
                'center': center,
                'scale': scale,
                'bbox': bbox,
                'rotation': 0,
                'joints_3d': joints_3d,
                'joints_3d_visible': joints_3d_visible,
                'bbox_score': 1,
                'bbox_id': idx,
            }
            db.append(sample)

        return db

    def _xywh2cs(self, x, y, w, h):
        """This encodes bbox(x, y, w, h) into (center, scale)
        Args:
            x, y, w, h
        Returns:
            tuple: A tuple containing center and scale.
            - center (np.ndarray[float32](2,)): center of the bbox (x, y).
            - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
        """
        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
            'image_size'][1]
        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio

        # pixel std is 200.0
        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
        # padding to include proper amount of context
        scale = scale * 1.25
        return center, scale
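
    # Note: _xywh2cs assumes self.ann_info['image_size'] is set (it comes from
    # data_cfg) and that w and h are positive; a zero-sized bbox gives a zero
    # scale, which can surface as NaNs in the UDP affine transform downstream.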

    def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
        """Evaluate keypoint detection results. The pose prediction results will
        be saved in `${res_folder}/result_keypoints.json`.

        Note:
            - batch size: N
            - num keypoints: K
            - heatmap height: H
            - heatmap width: W

        Args:
            outputs (list): Each element contains
                - preds (np.ndarray[N,K,3]): The first two dimensions are
                  coordinates; the third is the score.
                - boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
                  scale[1], area, score]
                - image_paths (list[str]): For example, ['Test/source/0.jpg']
                - output_heatmap (np.ndarray[N, K, H, W]): model outputs.
            res_folder (str): Path of the directory to save the results.
            metric (str | list[str]): Metric to be performed.
                Options: 'PCK', 'NME'.

        Returns:
            dict: Evaluation results for evaluation metric.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['PCK', 'NME']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        res_file = os.path.join(res_folder, 'result_keypoints.json')

        kpts = []
        for output in outputs:
            preds = output['preds']
            boxes = output['boxes']
            image_paths = output['image_paths']
            bbox_ids = output['bbox_ids']

            batch_size = len(image_paths)
            for i in range(batch_size):
                kpts.append({
                    'keypoints': preds[i].tolist(),
                    'center': boxes[i][0:2].tolist(),
                    'scale': boxes[i][2:4].tolist(),
                    'area': float(boxes[i][4]),
                    'score': float(boxes[i][5]),
                    'bbox_id': bbox_ids[i]
                })
        kpts = self._sort_and_unique_bboxes(kpts)

        self._write_keypoint_results(kpts, res_file)
        info_str = self._report_metric(res_file, metrics)
        name_value = OrderedDict(info_str)

        return name_value

    def _report_metric(self, res_file, metrics, pck_thr=0.3):
        """Keypoint evaluation.

        Args:
        res_file (str): Json file stored prediction results.
        metrics (str | list[str]): Metric to be performed.
            Options: 'PCK', 'NME'.
        pck_thr (float): PCK threshold, default: 0.3.

        Returns:
        dict: Evaluation results for evaluation metric.
        """
        info_str = []

        with open(res_file, 'r') as fin:
            preds = json.load(fin)
        assert len(preds) == len(self.db)

        outputs = []
        gts = []
        masks = []

        for pred, item in zip(preds, self.db):
            outputs.append(np.array(pred['keypoints'])[:, :-1])
            gts.append(np.array(item['joints_3d'])[:, :-1])
            masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)

        outputs = np.array(outputs)
        gts = np.array(gts)
        masks = np.array(masks)

        normalize_factor = self._get_normalize_factor(gts)

        if 'PCK' in metrics:
            _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr,
                                              normalize_factor)
            info_str.append(('PCK', pck))

        if 'NME' in metrics:
            info_str.append(
                ('NME', keypoint_nme(outputs, gts, masks, normalize_factor)))

        return info_str

    @staticmethod
    def _write_keypoint_results(keypoints, res_file):
        """Write results into a json file."""

        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)

    @staticmethod
    def _sort_and_unique_bboxes(kpts, key='bbox_id'):
        """sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts

    @staticmethod
    def _get_normalize_factor(gts):
        """Get inter-ocular distance as the normalize factor, measured as the
        Euclidean distance between the outer corners of the eyes.

        Args:
            gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.

        Return:
            np.ndarray[N, 2]: normalized factor
        """

        interocular = np.linalg.norm(
            gts[:, 0, :] - gts[:, 1, :], axis=1, keepdims=True)
        return np.tile(interocular, [1, 2])
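
For reference, a quick scan over the loaded samples can help narrow down where a NaN first appears (a sketch; check_db is a made-up helper name, and dataset is assumed to be an instance of the class above):

import numpy as np

def check_db(dataset):
    """Report samples whose bbox, center, scale or joints look degenerate."""
    for i, s in enumerate(dataset.db):
        x, y, w, h = s['bbox']
        if w <= 0 or h <= 0:
            print(f'sample {i}: non-positive bbox {s["bbox"]} ({s["image_file"]})')
        for key in ('center', 'scale', 'joints_3d'):
            if np.isnan(np.asarray(s[key], dtype=np.float64)).any():
                print(f'sample {i}: NaN in {key} ({s["image_file"]})')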

Environment

sys.platform: linux
Python: 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51) [GCC 9.4.0]
CUDA available: True
GPU 0,1,2,3: Tesla V100-DGXS-32GB
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.6.r11.6/compiler.30794723_0
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 1.11.0a0+bfe5ad2
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2019.0.5 Product Build 20190808 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash N/A)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.6
  - NVCC architecture flags: -gencode;arch=compute_52,code=sm_52;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_86,code=compute_86
  - CuDNN 8.3.2  (built against CUDA 11.5)
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.6, CUDNN_VERSION=8.3.2, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS=-fno-gnu-unique -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=ON, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.12.0a0
OpenCV: 4.5.5
MMCV: 1.4.6
MMCV Compiler: GCC 9.3
MMCV CUDA Compiler: 11.6
MMPose: 0.23.0+d612605

Error Traceback

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3232/3683948197.py in <module>
     18 
     19 # train model
---> 20 train_model(
     21     model, datasets, cfg, distributed=False, validate=True, meta=dict())

/workspace/mmpose/mmpose/apis/train.py in train_model(model, dataset, cfg, distributed, validate, timestamp, meta)
    189     elif cfg.load_from:
    190         runner.load_checkpoint(cfg.load_from)
--> 191     runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

/workspace/mmcv/mmcv/runner/epoch_based_runner.py in run(self, data_loaders, workflow, max_epochs, **kwargs)
    125                     if mode == 'train' and self.epoch >= self._max_epochs:
    126                         break
--> 127                     epoch_runner(data_loaders[i], **kwargs)
    128 
    129         time.sleep(1)  # wait for some hooks like loggers to finish

/workspace/mmcv/mmcv/runner/epoch_based_runner.py in train(self, data_loader, **kwargs)
     45         self.call_hook('before_train_epoch')
     46         time.sleep(2)  # Prevent possible deadlock during epoch transition
---> 47         for i, data_batch in enumerate(self.data_loader):
     48             self._inner_iter = i
     49             self.call_hook('before_train_iter')

/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1201             else:
   1202                 del self._task_info[idx]
-> 1203                 return self._process_data(data)
   1204 
   1205     def _try_put_index(self):

/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1227         self._try_put_index()
   1228         if isinstance(data, ExceptionWrapper):
-> 1229             data.reraise()
   1230         return data
   1231 

/opt/conda/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    436             # instantiate since we don't know how to
    437             raise RuntimeError(msg) from None
--> 438         raise exception
    439 
    440 

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/workspace/mmpose/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py", line 274, in __getitem__
    return self.pipeline(results)
  File "/workspace/mmpose/mmpose/datasets/pipelines/shared_transform.py", line 99, in __call__
    data = t(data)
  File "/workspace/mmpose/mmpose/datasets/pipelines/top_down_transform.py", line 629, in __call__
    else:
  File "/workspace/mmpose/mmpose/datasets/pipelines/top_down_transform.py", line 484, in _udp_generate_target
    feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
ValueError: cannot convert float NaN to integer
spoonbobo commented 2 years ago

Hi, does anyone have any ideas on how to solve this error? Thanks.

ly015 commented 2 years ago

Could you please check whether image_size and heatmap_size are properly set in the ann_info of the data samples? Please check the following references:
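
For example, something along these lines should show whether the sizes actually reach the samples (a sketch; train_dataset is assumed to be the dataset instance built from the config):

import numpy as np

ann_info = train_dataset.ann_info
print('image_size:', ann_info.get('image_size'))
print('heatmap_size:', ann_info.get('heatmap_size'))
for key in ('image_size', 'heatmap_size'):
    assert not np.isnan(np.asarray(ann_info[key], dtype=np.float64)).any()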

EEWenbinWu commented 2 years ago

Hi, does anyone have any ideas on how to solve this error? Thanks.