open-mmlab / mmpose

OpenMMLab Pose Estimation Toolbox and Benchmark.
https://mmpose.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Try to train custom datasets #2368

Closed kdavidlp123 closed 1 year ago

kdavidlp123 commented 1 year ago

Environment

OrderedDict([('sys.platform', 'win32'), ('Python', '3.8.16 (default, Mar 2 2023, 03:18:16) [MSC v.1916 64 bit (AMD64)]'), ('CUDA available', True), ('numpy_random_seed', 2147483648), ('GPU 0', 'NVIDIA GeForce RTX 3090'), ('CUDA_HOME', 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1'), ('NVCC', 'Cuda compilation tools, release 11.1, V11.1.105'), ('MSVC', 'Microsoft (R) C/C++ Optimizing Compiler Version 19.29.30148 for x64'), ('GCC', 'n/a'), ('PyTorch', '1.9.1+cu111'), ('PyTorch compiling details', 'PyTorch built with:\n - C++ Version: 199711\n - MSVC 192829337\n - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n - OpenMP 2019\n - CPU capability usage: AVX2\n - CUDA Runtime 11.1\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37\n - CuDNN 8.0.5\n - Magma 2.5.4\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=C:/w/b/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -IC:/w/b/windows/mkl/include -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, \n'), ('TorchVision', '0.10.1+cu111'), ('OpenCV', '4.7.0'), ('MMEngine', '0.7.3'), ('MMPose', '1.0.0+2c4a60e')])

Reproduces the problem - code sample

For the configs\_base_\datasets\custom.py:

dataset_info = dict(
    dataset_name='custom',
    paper_info=dict(
        author='Lin, Tsung-Yi and Maire, Michael and '
        'Belongie, Serge and Hays, James and '
        'Perona, Pietro and Ramanan, Deva and '
        r'Doll{\'a}r, Piotr and Zitnick, C Lawrence',
        title='Microsoft coco: Common objects in context',
        container='European conference on computer vision',
        year='2014',
        homepage='http://cocodataset.org/',
    ),
    keypoint_info={
        0:
        dict(name='0', id=0, color=[51, 153, 255], type='', swap=''),
        1:
        dict(
            name='1',
            id=1,
            color=[51, 153, 255],
            type='',
            swap=''),
        2:
        dict(
            name='2',
            id=2,
            color=[51, 153, 255],
            type='',
            swap=''),
        3:
        dict(
            name='3',
            id=3,
            color=[51, 153, 255],
            type='',
            swap=''),
        4:
        dict(
            name='4',
            id=4,
            color=[51, 153, 255],
            type='',
            swap=''),
        5:
        dict(
            name='5',
            id=5,
            color=[0, 255, 0],
            type='',
            swap=''),
        6:
        dict(
            name='6',
            id=6,
            color=[255, 128, 0],
            type='',
            swap=''),
        7:
        dict(
            name='7',
            id=7,
            color=[0, 255, 0],
            type='',
            swap=''),
        8:
        dict(
            name='8',
            id=8,
            color=[255, 128, 0],
            type='',
            swap=''),
        9:
        dict(
            name='9',
            id=9,
            color=[0, 255, 0],
            type='',
            swap=''),
        10:
        dict(
            name='10',
            id=10,
            color=[255, 128, 0],
            type='',
            swap=''),
        11:
        dict(
            name='11',
            id=11,
            color=[0, 255, 0],
            type='',
            swap=''),
        12:
        dict(
            name='12',
            id=12,
            color=[255, 128, 0],
            type='',
            swap=''),
        13:
        dict(
            name='13',
            id=13,
            color=[0, 255, 0],
            type='',
            swap=''),
        14:
        dict(
            name='14',
            id=14,
            color=[255, 128, 0],
            type='',
            swap=''),
        15:
        dict(
            name='15',
            id=15,
            color=[0, 255, 0],
            type='',
            swap=''),
        16:
        dict(
            name='16',
            id=16,
            color=[255, 128, 0],
            type='',
            swap=''),
        17:
        dict(
            name='17',
            id=17,
            color=[255, 128, 0],
            type='',
            swap=''),
        18:
        dict(
            name='18',
            id=18,
            color=[255, 128, 0],
            type='',
            swap=''),
        19:
        dict(
            name='19',
            id=19,
            color=[255, 128, 0],
            type='',
            swap=''),
        20:
        dict(
            name='20',
            id=20,
            color=[255, 128, 0],
            type='',
            swap=''),
    },
    skeleton_info={
    },
    joint_weights=[
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
    ],
    sigmas=[
        0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047, 0.047
    ])

For the configs\body_2d_keypoint\topdown_heatmap\coco\custom_model.py:

_base_ = ['../../../_base_/default_runtime.py',
          '../../../_base_/datasets/custom.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=5e-4,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap',
    input_size=(288, 384),
    heatmap_size=(72, 96),
    sigma=3,
    unbiased=True)

channel_cfg = dict(
    num_output_channels = 21,
    dataset_joints = 21,
    dataset_channel=[
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

    ],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20

    ])

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=2048,
        out_channels=21,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'custom'
data_mode = 'topdown'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', target_type='heatmap', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', file_client_args={{_base_.file_client_args}}),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=64,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='training_json/training.json',
        data_prefix=dict(img='training/'),
        pipeline=train_pipeline,
    ))
# val_dataloader = None

val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='training_json/training.json',
        # bbox_file='data/coco/person_detection_results/'
        # 'COCO_val2017_detections_AP_H_56_person.json',
        data_prefix=dict(img='training/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'training_json/training.json')
test_evaluator = val_evaluator

Because I just wanted to run the script first, I used the training data as the validation data.

Reproduces the problem - command or script


(openmmlab) PS D:\mmpose> python tools/train.py configs/body_2d_keypoint/topdown_heatmap/coco/custom_model.py
Traceback (most recent call last):
  File "tools/train.py", line 160, in <module>
    main()
  File "tools/train.py", line 142, in main
    cfg = Config.fromfile(args.config)
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 178, in fromfile
    cfg_dict, cfg_text, env_variables = Config._file2dict(
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 560, in _file2dict
    cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 423, in _substitute_base_vars
    cfg[k] = Config._substitute_base_vars(
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 430, in _substitute_base_vars
    cfg = [
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 431, in <listcomp>
    Config._substitute_base_vars(c, base_var_dict, base_cfg)
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 420, in _substitute_base_vars
    new_v = new_v[new_k]
  File "C:\Users\user\anaconda3\envs\openmmlab\lib\site-packages\mmengine\config\config.py", line 48, in missing
    raise KeyError(name)
KeyError: 'file_client_args'

Reproduces the problem - error message

No

Additional information

Do I need to modify any other files to train on custom datasets?

And do I also need to add and register a new dataset class under mmpose\datasets\datasets\body? It seems a little different from the old version.

Thank you!!!

Ben-Louis commented 1 year ago

The file_client_args parameter is no longer in use and has been deprecated. If you're loading data directly from a disk, it's safe to remove this argument from your configuration.
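For example, if you load images directly from disk, the two pipelines in your config simply become something like this (a sketch; codec is the codec dict already defined earlier in the config):

# pipelines without the deprecated file_client_args argument
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', target_type='heatmap', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]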

In addition, it's recommended that you include metainfo in your dataset configuration. This can be done following the instructions provided in our user guide on preparing datasets, specifically in the section on using a custom dataset. Here is the link for your reference: https://mmpose.readthedocs.io/en/latest/user_guides/prepare_datasets.html#use-a-custom-dataset.
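For instance, if your annotations follow the COCO format, the dataset settings can reuse an existing COCO-style dataset class and point metainfo at your meta info file, roughly like this (a sketch; the paths are copied from your config and may need adjusting):

# reuse a registered COCO-style dataset class and supply the custom meta info
dataset = dict(
    type='CocoDataset',
    data_root='data/coco/',
    data_mode='topdown',
    ann_file='training_json/training.json',
    data_prefix=dict(img='training/'),
    metainfo=dict(from_file='configs/_base_/datasets/custom.py'),
    pipeline=train_pipeline,  # the pipeline defined earlier in the config
)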

Ben-Louis commented 1 year ago

The 'area' refers to the number of pixels of the instance. The items 'segmentation' and 'area' are not necessary.
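So a single instance in the annotation file only needs the basic COCO fields, for example (a sketch in Python dict form; every value here is a placeholder):

# minimal COCO-style instance annotation; 'segmentation' and 'area' are omitted
annotation = dict(
    id=1,
    image_id=1,
    category_id=1,
    bbox=[100.0, 50.0, 200.0, 300.0],   # [x, y, w, h] in pixels
    keypoints=[120.0, 80.0, 2] * 21,    # 21 keypoints as (x, y, visibility) triplets
    num_keypoints=21,
    iscrowd=0,
)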

kdavidlp123 commented 1 year ago


I followed the instructions and created a file under mmpose\datasets\datasets\body as below:

# Copyright (c) OpenMMLab. All rights reserved.
from mmpose.registry import DATASETS
from mmengine.dataset import BaseDataset

@DATASETS.register_module(name='customdataset')
class customdataset(BaseDataset):

    METAINFO: dict = dict(from_file='configs/_base_/datasets/custom.py')

After that, I modified the __init__.py under mmpose\datasets\datasets\body as follows:

from .aic_dataset import AicDataset
from .coco_dataset import CocoDataset
from .crowdpose_dataset import CrowdPoseDataset
from .jhmdb_dataset import JhmdbDataset
from .mhp_dataset import MhpDataset
from .mpii_dataset import MpiiDataset
from .mpii_trb_dataset import MpiiTrbDataset
from .ochuman_dataset import OCHumanDataset
from .posetrack18_dataset import PoseTrack18Dataset
from .posetrack18_video_dataset import PoseTrack18VideoDataset
from .custom_dataset import customdataset

__all__ = [
    'CocoDataset', 'MpiiDataset', 'MpiiTrbDataset', 'AicDataset',
    'CrowdPoseDataset', 'OCHumanDataset', 'MhpDataset', 'PoseTrack18Dataset',
    'JhmdbDataset', 'PoseTrack18VideoDataset', 'customdataset'
]

I am not sure how to do step 3, and did I miss something that needs to be modified for the dataset registry?

kdavidlp123 commented 1 year ago

It turned out that I had used the wrong file path for data_root in the custom_model config file. I corrected it and training started. Thank you for your kindness and patience!!!! You can close the issue now!! Thank you
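For anyone who runs into the same thing: data_root is prepended to both ann_file and data_prefix, so the files have to live where the combined paths point (a sketch using the values from the config above):

# how the dataset paths in the config are combined
data_root = 'data/coco/'
ann_file = 'training_json/training.json'    # resolved as data_root + ann_file
data_prefix = dict(img='training/')         # images resolved as data_root + data_prefix['img']
# expected on disk:
#   data/coco/training_json/training.json
#   data/coco/training/<image files>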