ViTAE-Transformer / ViTPose

The official repo for [NeurIPS'22] "ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation" and [TPAMI'23] "ViTPose++: Vision Transformer for Generic Body Pose Estimation"
Apache License 2.0

Question about how to train multiple datasets. #106

Open Ironbrotherstyle opened 1 year ago

Ironbrotherstyle commented 1 year ago

Hi, thank you for your work. I have a question about how to train on multiple datasets. For example, when I wanted to train on AIC and COCO together, I used a config file modified from vitPose+_large_coco+aic+mpii+ap10k+apt36k+wholebody_256x192_udp.py, but got an error about num_joints because AIC and COCO have different numbers of joints.

[image: error screenshot]

Is there any advice on training AIC, COCO and MPII together? Thanks.

ValterH commented 1 year ago

I suppose the issue is that AIC has 14 keypoints and thus 14 output channels, while COCO has 17. I think the multi-dataset config is intended to train an individual head for each dataset. If you are trying to train a single head, you would somehow have to add the misaligned COCO joints to the AIC annotations. However, I think the original config makes more sense.
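For the single-head route, one option is to re-map every AIC annotation into the 17-joint COCO layout and mark the joints AIC lacks as invisible. The sketch below is a hypothetical helper, not ViTPose code, and the index table is my assumption based on the usual AIC and COCO joint orders (check it against the dataset_info files before relying on it). The idea is that in mmpose-style pipelines a visibility of 0 gives a target weight of 0, so with `use_target_weight=True` those joints add no loss.

```python
# Hypothetical helper: re-map a 14-joint AIC annotation into the 17-joint
# COCO layout so both datasets can share one head.
# ASSUMPTION: the index table reflects the usual AIC / COCO joint orders;
# verify it against your dataset_info before use.

# COCO index -> AIC index, for joints the two datasets share.
COCO_FROM_AIC = {
    5: 3,    # left_shoulder
    6: 0,    # right_shoulder
    7: 4,    # left_elbow
    8: 1,    # right_elbow
    9: 5,    # left_wrist
    10: 2,   # right_wrist
    11: 9,   # left_hip
    12: 6,   # right_hip
    13: 10,  # left_knee
    14: 7,   # right_knee
    15: 11,  # left_ankle
    16: 8,   # right_ankle
}
NUM_AIC_JOINTS = 14
NUM_COCO_JOINTS = 17


def aic_to_coco(aic_keypoints):
    """Convert 14 AIC (x, y, visibility) triples into 17 COCO-ordered triples.

    COCO joints with no AIC counterpart (nose, eyes, ears) get visibility 0;
    AIC's head_top and neck are dropped.
    """
    if len(aic_keypoints) != NUM_AIC_JOINTS:
        raise ValueError(f"expected {NUM_AIC_JOINTS} AIC joints, got {len(aic_keypoints)}")
    coco = [[0.0, 0.0, 0] for _ in range(NUM_COCO_JOINTS)]
    for coco_idx, aic_idx in COCO_FROM_AIC.items():
        coco[coco_idx] = list(aic_keypoints[aic_idx])
    return coco
```

The trade-off is that the head never learns head_top and neck from AIC, which is one reason separate heads per dataset can be the more natural fit.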

Ironbrotherstyle commented 1 year ago

Thank you for your reply. I tried to add an associate_keypoint_head for the AIC and MPII datasets to the config .py file. Part of the code is as follows:

```python
# model settings (partial config: channel_cfg, aic_channel_cfg and
# mpii_channel_cfg are defined elsewhere in the file)
model = dict(
    type='TopDown',
    pretrained=None,
    backbone=dict(type='DNet', encoding="031_40_-11-21112-111112"),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=1024,
        out_channels=channel_cfg['num_output_channels'],
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    associate_keypoint_head=[
        dict(  # head for AIC
            type='TopdownHeatmapSimpleHead',
            in_channels=1024,
            num_deconv_layers=2,
            num_deconv_filters=(256, 256),
            num_deconv_kernels=(4, 4),
            extra=dict(final_conv_kernel=1, ),
            out_channels=aic_channel_cfg['num_output_channels'],
            loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
        dict(  # head for MPII
            type='TopdownHeatmapSimpleHead',
            in_channels=1024,
            num_deconv_layers=2,
            num_deconv_filters=(256, 256),
            num_deconv_kernels=(4, 4),
            extra=dict(final_conv_kernel=1, ),
            out_channels=mpii_channel_cfg['num_output_channels'],
            loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    ],
    # remaining model settings omitted
)
```

But I got this error:

[image: error screenshot]

It seems that the project does not currently support the associate_keypoint_head setting. I am very curious how the authors successfully trained the following model:

[image]

The provided config file ViTPose_large_ochuman_256x192.py only uses the COCO dataset for training:

```python
# Training uses COCO only; validation and testing use OCHuman.
data_root = 'data/ochuman'
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    val_dataloader=dict(samples_per_gpu=32),
    test_dataloader=dict(samples_per_gpu=32),
    train=dict(
        type='TopDownCocoDataset',
        ann_file='data/coco/annotations/person_keypoints_train2017.json',
        img_prefix='data/coco/train2017/',
        data_cfg=data_cfg,
        pipeline=train_pipeline,
        dataset_info={{_base_.dataset_info}}),
    val=dict(
        type='TopDownOCHumanDataset',
        ann_file=f'{data_root}/annotations/'
        'ochuman_coco_format_val_range_0.00_1.00.json',
        img_prefix=f'{data_root}/images/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,
        dataset_info={{_base_.dataset_info}}),
    test=dict(
        type='TopDownOCHumanDataset',
        ann_file=f'{data_root}/annotations/'
        'ochuman_coco_format_test_range_0.00_1.00.json',
        img_prefix=f'{data_root}/images/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,
        dataset_info={{_base_.dataset_info}}),
)
```
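To train on several datasets at once, the training set has to be a combination of datasets rather than the single COCO entry above. As far as I know, mmpose 0.x builds a ConcatDataset when `train` is given as a list of dataset dicts. The sketch below only illustrates how such a combined set maps a global sample index back to (dataset position, local index); the class name `ConcatIndex` is hypothetical and this is not mmpose's implementation.

```python
import bisect
from itertools import accumulate


class ConcatIndex:
    """Minimal sketch of ConcatDataset-style indexing (not mmpose's code).

    Maps a global sample index to (dataset_pos, local_idx), where dataset_pos
    is the position of the dataset in the combined training list.
    """

    def __init__(self, lengths):
        # Running totals, e.g. lengths [3, 5] -> [3, 8].
        self.cumulative = list(accumulate(lengths))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def locate(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose running total exceeds idx.
        dataset_pos = bisect.bisect_right(self.cumulative, idx)
        start = self.cumulative[dataset_pos - 1] if dataset_pos > 0 else 0
        return dataset_pos, idx - start
```

For example, with a 3-sample and a 5-sample dataset, global index 3 is the first sample of the second dataset, i.e. `(1, 0)`.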
phc-alchera commented 8 months ago

Did you solve it? I haven't reviewed all of the code yet, but mm-based models depend on several config parameters.

For the broadcast error, you need to add 'max_num_joints' to data_cfg:

```python
# channel_cfg is defined earlier in the config
data_cfg = dict(
    image_size=[192, 256],
    heatmap_size=[48, 64],
    num_output_channels=channel_cfg['num_output_channels'],
    num_joints=channel_cfg['dataset_joints'],
    dataset_channel=channel_cfg['dataset_channel'],
    inference_channel=channel_cfg['inference_channel'],
    soft_nms=False,
    nms_thr=1.0,
    oks_thr=0.9,
    vis_thr=0.2,
    use_gt_bbox=False,
    det_bbox_thr=0.0,
    bbox_file='/hdd1/Dataset/coco/person_detection_results/'
    'COCO_val2017_detections_AP_H_56_person.json',
    max_num_joints=133,  # 133 = COCO-WholeBody, the largest skeleton in the ViTPose+ configs
    dataset_idx=0,       # index of this dataset in the multi-dataset setup
)
```
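Why this fixes a "broadcast" error: samples from different datasets carry per-joint arrays of different lengths (14 for AIC, 17 for COCO), which cannot be stacked into one batch. My reading, which is an assumption and not taken from the ViTPose source, is that `max_num_joints` pads every sample to a common joint count. The hypothetical helpers below sketch that padding on the per-joint target weights, with zero weight in the padded slots so they contribute no loss.

```python
def pad_joints(target_weight, max_num_joints):
    """Pad a per-joint weight list with zeros up to max_num_joints.

    Hypothetical helper (assumed behaviour of max_num_joints, not ViTPose's code).
    """
    n = len(target_weight)
    if n > max_num_joints:
        raise ValueError(f"{n} joints exceeds max_num_joints={max_num_joints}")
    return list(target_weight) + [0.0] * (max_num_joints - n)


def collate(weights_per_sample, max_num_joints):
    """Pad each sample so the batch is rectangular (every row has max_num_joints entries)."""
    return [pad_joints(w, max_num_joints) for w in weights_per_sample]
```

With `max_num_joints=133`, an AIC sample and a COCO sample both become 133 entries long, and only their first 14 or 17 entries are non-zero.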

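If you want to activate 'associate_keypoint_head', insert num_expert=<number of your datasets> and part_features=320 into the backbone section of the model config. Note that the example below also uses the TopDownMoE model with the ViTMoE backbone: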

```python
# model settings
model = dict(
    type='TopDownMoE',
    pretrained=None,
    backbone=dict(
        type='ViTMoE',
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.55,
        num_expert=2,  # number of datasets: main dataset + CrowdPose here
        part_features=320
    ),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=1280,
        num_deconv_layers=2,
        num_deconv_filters=(256, 256),
        num_deconv_kernels=(4, 4),
        extra=dict(final_conv_kernel=1, ),
        out_channels=channel_cfg['num_output_channels'],
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    associate_keypoint_head=[
        dict(  # extra head for the second dataset (CrowdPose)
            type='TopdownHeatmapSimpleHead',
            in_channels=1280,
            num_deconv_layers=2,
            num_deconv_filters=(256, 256),
            num_deconv_kernels=(4, 4),
            extra=dict(final_conv_kernel=1, ),
            out_channels=crowdpose_channel_cfg['num_output_channels'],
            loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))],
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='default',
        shift_heatmap=False,
        target_type=target_type,
        modulate_kernel=11,
        use_udp=True))
```
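Since num_expert should equal the number of datasets, and the example pairs two experts with one keypoint_head plus one associate_keypoint_head, a quick consistency check on the config dict can catch mismatches before training. The sketch below assumes one head per dataset, with keypoint_head serving dataset 0 and associate_keypoint_head[i] serving dataset i + 1; the helper names are hypothetical, and `route_by_dataset` only illustrates grouping a batch by each sample's dataset_idx, not how TopDownMoE is actually implemented.

```python
def check_moe_config(model_cfg):
    """Check that backbone num_expert matches the number of keypoint heads.

    Assumption: one head per dataset (keypoint_head + each associate head).
    """
    num_heads = 1 + len(model_cfg.get('associate_keypoint_head', []))
    num_expert = model_cfg['backbone']['num_expert']
    if num_expert != num_heads:
        raise ValueError(
            f"num_expert={num_expert} but the config defines {num_heads} heads")
    return num_heads


def route_by_dataset(dataset_indices):
    """Group batch positions by dataset_idx, i.e. by the head that should score them."""
    groups = {}
    for pos, ds in enumerate(dataset_indices):
        groups.setdefault(ds, []).append(pos)
    return groups
```

For the config above, `check_moe_config(model)` would return 2; adding an MPII head without raising num_expert to 3 would raise a ValueError.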

I hope this helps.