open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Problem training on custom dataset #4053

Closed waddington closed 3 years ago

waddington commented 3 years ago

Hello,

I am experiencing some odd behavior when trying to train on a custom dataset in COCO format. So far I have tried training Faster R-CNN and RetinaNet models, but the bbox loss is always zero. Then, at validation time, it reports that the "test results for entire dataset are empty". I can assure you that the JSON files for the dataset are OK: the bbox information is there and correct, and the paths to the images are correct.

I have tried using the same dataset to train an SSD model, but I get a "list index out of range" error related to the class list. If I set a breakpoint and inspect how many classes MMDetection thinks there are (fewer than there actually are), and then set the model's number of classes to that number, training seems to work at first. However, if I perform classwise validation, many of the classes are listed 8 or 9 times, and not all of the classes are listed.

I have previously created a custom dataset in COCO format with no problems, and I am struggling to find a cause or solution for this issue, especially because of the lack of useful error messages and the poor documentation of how the system works.

Any help is appreciated.

I also have some more general questions:

  1. Generating IDs for images, categories, and annotations in a COCO-format dataset: do these IDs have to be unique across images, categories, and annotations, or can a category have ID "1" and an image also have ID "1"?
  2. Is there a way to perform classwise validation during training?
  3. Is there a way to also save the processed images during the validations run during training? Also, when listing class names in a custom dataset config file (or class file), does the position of a class name within the list determine what its ID should be? I.e. if I have the list classes = ('cat', 'dog', 'monkey'), should "cat" have ID "0" in the dataset files and "monkey" have ID "2"?
  4. Do all images in a dataset need to be the same size? I think not, because they get resized as part of the transformations, but I'm not 100% sure.

Thanks

ZwwWayne commented 3 years ago
  1. They do not need to be unique across types: a category can have ID 1 and an image can also have ID 1. The two are unrelated.
  2. You can set classwise=True in the evaluation config for the COCO dataset.
  3. 'cat' does not need to have ID '0' in the annotation files. The IDs can be different; for the COCO dataset we build a mapping between them.
  4. No. They will be resized during training.
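To make answer 4 concrete: as far as I know, the keep-ratio rescaling that mmcv applies for Resize with keep_ratio=True picks the largest factor such that the long edge stays within max(img_scale) and the short edge within min(img_scale). The sketch below reimplements that rule in plain Python; keep_ratio_size is a hypothetical helper, not an mmcv API. It shows that images of different sizes end up at different, but bounded, sizes.

```python
def keep_ratio_size(width, height, img_scale=(1333, 800)):
    """Return the (width, height) an image is rescaled to with keep_ratio=True.

    Hypothetical helper mirroring the keep-ratio rule: the long edge must fit
    within max(img_scale) and the short edge within min(img_scale).
    """
    max_long_edge = max(img_scale)
    max_short_edge = min(img_scale)
    scale = min(max_long_edge / max(width, height),
                max_short_edge / min(width, height))
    # Round to the nearest integer pixel size
    return int(width * scale + 0.5), int(height * scale + 0.5)


# Two image sizes that appear in this thread (2048x1536 and 2048x1494)
print(keep_ratio_size(2048, 1536))  # (1067, 800)
print(keep_ratio_size(2048, 1494))  # (1097, 800)
```

The Pad transform with size_divisor=32 then pads each result up to a multiple of 32, so batches can still be stacked.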

For your loss issue, can you provide more information, e.g., the version of mmdet and the details of your dataset? We suggest you use the reimplementation issue template to report it. With the information provided so far, we are not able to help.

waddington-ou-phd-1 commented 3 years ago

Hi @ZwwWayne , thank you for answering my questions so far.

Very sorry for the slow reply. I have added more information:

(I didn't realise I was on a different account, but I am also the OP.)

I am having an issue with a custom dataset where the loss_bbox is always 0, and then at validation time it reports "The testing results of the whole dataset is empty".

I have searched through a lot of issues here and found people with the same problem, but I don't see a solution.

The problem:

I am using my own dataset (actually the Caltech Camera Traps dataset), which I convert to MS COCO format. The format it comes in is already very close to MS COCO, so I don't anticipate format problems (though I hope I'm wrong, because that would be an easy fix!). I do not implement a custom dataset; instead, I create a new config of type "CocoDataset" and redefine the classes. I get this issue with Faster R-CNN and RetinaNet; I haven't tried any other models yet.

Full config file:

The dataset config is a copy of the MS COCO dataset config, except that I add classes = (...) and then add classes=classes to each data stage dictionary, which is correctly expanded to the class names. I have also tried batch sizes of 1, 2, and 4; this doesn't affect the problem.

If I change the num_classes value in the model config file to 22, which is the number of classes in this dataset, I still have the problem. As an aside, is it necessary to change this number when using a new dataset? I don't think so, based on the documentation, but I don't know for certain; it would make sense either way.

model = dict(
    type='RetinaNet',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5),
    bbox_head=dict(
        type='RetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)))
train_cfg = dict(
    assigner=dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
test_cfg = dict(
    nms_pre=1000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='nms', iou_threshold=0.5),
    max_per_img=100)
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
dataset_type = 'CocoDataset'
data_root = 'data/cct/'
classes = ('opossum', 'raccoon', 'squirrel', 'bobcat', 'skunk', 'dog',
           'coyote', 'rabbit', 'bird', 'lizard', 'cat', 'badger', 'empty',
           'car', 'deer', 'cow', 'pig', 'mountain_lion', 'fox', 'bat',
           'insect', 'rodent')
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type='CocoDataset',
        classes=('opossum', 'raccoon', 'squirrel', 'bobcat', 'skunk', 'dog',
                 'coyote', 'rabbit', 'bird', 'lizard', 'cat', 'badger',
                 'empty', 'car', 'deer', 'cow', 'pig', 'mountain_lion', 'fox',
                 'bat', 'insect', 'rodent'),
        ann_file=
        'data/cct/annotations/caltech_bboxes_20200316-FILEPATHS-VALID-COCO-TRAIN.json',
        img_prefix='data/cct/data/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]),
    val=dict(
        type='CocoDataset',
        classes=('opossum', 'raccoon', 'squirrel', 'bobcat', 'skunk', 'dog',
                 'coyote', 'rabbit', 'bird', 'lizard', 'cat', 'badger',
                 'empty', 'car', 'deer', 'cow', 'pig', 'mountain_lion', 'fox',
                 'bat', 'insect', 'rodent'),
        ann_file=
        'data/cct/annotations/caltech_bboxes_20200316-FILEPATHS-VALID-COCO-VAL.json',
        img_prefix='data/cct/data/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CocoDataset',
        classes=('opossum', 'raccoon', 'squirrel', 'bobcat', 'skunk', 'dog',
                 'coyote', 'rabbit', 'bird', 'lizard', 'cat', 'badger',
                 'empty', 'car', 'deer', 'cow', 'pig', 'mountain_lion', 'fox',
                 'bat', 'insect', 'rodent'),
        ann_file=
        'data/cct/annotations/caltech_bboxes_20200316-FILEPATHS-VALID-COCO-VAL.json',
        img_prefix='data/cct/data/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='fixed', warmup='linear', warmup_iters=500, warmup_ratio=0.001)
total_epochs = 1
work_dir = './work_dirs/retinanet-CCT-SGDLR0001M9-BS1-e1'
gpu_ids = range(0, 1)
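For comparison, the same setup can usually be written much more compactly by inheriting from the stock RetinaNet config and overriding only what changes. This is a sketch assuming the file is saved in a new, hypothetical configs/cct/ folder of the mmdet 2.8 repository, so that _base_ points at configs/retinanet/retinanet_r50_fpn_1x_coco.py; the annotation paths are the ones from the full config above. It also sets num_classes to the number of classes and turns on classwise=True in the evaluation config (answer 2).

```python
# Hypothetical file: configs/cct/retinanet_r50_fpn_1x_cct.py
_base_ = '../retinanet/retinanet_r50_fpn_1x_coco.py'

classes = ('opossum', 'raccoon', 'squirrel', 'bobcat', 'skunk', 'dog',
           'coyote', 'rabbit', 'bird', 'lizard', 'cat', 'badger', 'empty',
           'car', 'deer', 'cow', 'pig', 'mountain_lion', 'fox', 'bat',
           'insect', 'rodent')

# The head must predict one score per entry in `classes`
model = dict(bbox_head=dict(num_classes=22))

ann_dir = 'data/cct/annotations/'
img_prefix = 'data/cct/data/'
data = dict(
    train=dict(
        classes=classes,
        ann_file=ann_dir +
        'caltech_bboxes_20200316-FILEPATHS-VALID-COCO-TRAIN.json',
        img_prefix=img_prefix),
    val=dict(
        classes=classes,
        ann_file=ann_dir + 'caltech_bboxes_20200316-FILEPATHS-VALID-COCO-VAL.json',
        img_prefix=img_prefix),
    test=dict(
        classes=classes,
        ann_file=ann_dir + 'caltech_bboxes_20200316-FILEPATHS-VALID-COCO-VAL.json',
        img_prefix=img_prefix))

# Report per-class AP in the validation run during training
evaluation = dict(interval=1, metric='bbox', classwise=True)
```

Keys that are not overridden (pipelines, optimizer, schedule) are inherited from the base files, which keeps the custom part short enough to review by eye.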

Invocation Command

The command I used is python tools/train.py path/to/config/file

MMDetection Version

I am using version 2.8.0 from the tag "v2.8.0"


I haven't made any changes to the code, I have only added my config for the new dataset, which is "CocoDataset" type dataset in MS COCO format.

Environment

sys.platform: linux
Python: 3.7.7 (default, May  7 2020, 21:25:33) [GCC 7.3.0]
CUDA available: True
GPU 0,1: Tesla P40
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.1, V10.1.243
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.6.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.1 Product Build 20200208 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, 

TorchVision: 0.7.0
OpenCV: 4.5.1
MMCV: 1.2.5
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 10.1
MMDetection: 2.8.0+c8c8bfe

Small Annotations file example

Annotation example:

        {
            "image_id": 5357575699102975245505110050454949101564597549751451019948565498485054494898,
            "id": 509753525353504845999810249454949101564556495799455755489757525348991009899,
            "bbox": [
                499.2,
                711.68,
                353.28000000000003,
                199.67999999999995
            ],
            "category_id": 1,
            "area": 0,
            "iscrowd": 0,
            "segmentation": []
        }

I have manually checked that the IDs correctly match the image entries and the categories.
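The image and annotation IDs in this file are 76-digit integers, far beyond the range of a 64-bit integer. Python integers have arbitrary precision, so the JSON loads fine, but I am not certain every component in the training and evaluation pipeline handles IDs that large. A cheap way to rule this out is to re-index them into small sequential integers. The sketch below uses only the standard json module; reindex_coco_ids and reindex_file are hypothetical helper names.

```python
import json


def reindex_coco_ids(coco):
    """Replace image and annotation IDs with 1-based sequential integers.

    Category IDs are left untouched; each annotation's image_id is remapped
    so it still points at the same image.
    """
    image_id_map = {}
    for new_id, image in enumerate(coco['images'], start=1):
        image_id_map[image['id']] = new_id
        image['id'] = new_id
    for new_id, ann in enumerate(coco['annotations'], start=1):
        ann['image_id'] = image_id_map[ann['image_id']]
        ann['id'] = new_id
    return coco


def reindex_file(src_path, dst_path):
    """Read a COCO JSON file, re-index its IDs, and write it to dst_path."""
    with open(src_path) as f:
        coco = json.load(f)
    with open(dst_path, 'w') as f:
        json.dump(reindex_coco_ids(coco), f)
```

For example, reindex_file('train.json', 'train-reindexed.json') with your own file names (these are placeholders), then point ann_file at the new file.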

Image example:

        {
            "file_name": "cct_images/5a096a4f-23d2-11e8-a6a3-ec086b02610b.jpg",
            "seq_id": "700b7e45-5567-11e8-8b04-dca9047ef277",
            "height": 1536,
            "id": 5397485754975210245505110050454949101564597549751451019948565498485054494898,
            "frame_num": 1,
            "width": 2048,
            "date_captured": "2011-12-14 08:02:00",
            "rights_holder": "Erin Boydston",
            "seq_num_frames": 1,
            "location": "0"
        }

I have manually checked that the image paths are correct.
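One field in the annotation example stands out: "area": 0. To my knowledge, the mmdet 2.x CocoDataset skips any annotation whose area is <= 0, or whose box is less than 1 pixel wide or tall, when it parses the ground truth, whereas a visual check with pycocotools does not depend on area at all. If that is the case here, every box would be dropped before reaching the model. The sketch below (hypothetical helper names, boxes assumed to be in COCO [x, y, width, height] format) flags annotations such a filter would reject and fills in area from the box size.

```python
def box_area(bbox):
    """Area of a COCO-format [x, y, width, height] box."""
    _, _, w, h = bbox
    return w * h


def would_be_skipped(ann):
    """Assumed filter: non-positive area, or width/height below 1 pixel."""
    _, _, w, h = ann['bbox']
    return ann.get('area', 0) <= 0 or w < 1 or h < 1


def fill_missing_areas(coco):
    """Set area = width * height wherever area is missing or non-positive.

    Returns the number of annotations that were modified.
    """
    fixed = 0
    for ann in coco['annotations']:
        if ann.get('area', 0) <= 0:
            ann['area'] = box_area(ann['bbox'])
            fixed += 1
    return fixed
```

On the loaded JSON, sum(would_be_skipped(a) for a in coco['annotations']) gives the count; if it equals the number of annotations, no ground-truth box would survive such a filter.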

Training log

Beginning:

2021-01-11 14:36:24,615 - mmdet - INFO - workflow: [('train', 1)], max: 1 epochs
2021-01-11 14:36:37,459 - mmdet - INFO - Epoch [1][50/37456]    lr: 9.890e-06, eta: 2:39:28, time: 0.256, data_time: 0.049, memory: 1693, loss_cls: 10.0877, loss_bbox: 0.0000, loss: 10.0877
2021-01-11 14:36:47,706 - mmdet - INFO - Epoch [1][100/37456]   lr: 1.988e-05, eta: 2:23:25, time: 0.205, data_time: 0.004, memory: 1693, loss_cls: 10.0956, loss_bbox: 0.0000, loss: 10.0956
2021-01-11 14:36:57,921 - mmdet - INFO - Epoch [1][150/37456]   lr: 2.987e-05, eta: 2:17:50, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 10.0472, loss_bbox: 0.0000, loss: 10.0472
2021-01-11 14:37:08,138 - mmdet - INFO - Epoch [1][200/37456]   lr: 3.986e-05, eta: 2:14:57, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 10.0342, loss_bbox: 0.0000, loss: 10.0342
2021-01-11 14:37:18,366 - mmdet - INFO - Epoch [1][250/37456]   lr: 4.985e-05, eta: 2:13:11, time: 0.205, data_time: 0.004, memory: 1693, loss_cls: 9.9923, loss_bbox: 0.0000, loss: 9.9923
2021-01-11 14:37:28,577 - mmdet - INFO - Epoch [1][300/37456]   lr: 5.984e-05, eta: 2:11:55, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 9.8801, loss_bbox: 0.0000, loss: 9.8801
2021-01-11 14:37:38,905 - mmdet - INFO - Epoch [1][350/37456]   lr: 6.983e-05, eta: 2:11:10, time: 0.207, data_time: 0.004, memory: 1693, loss_cls: 5.9481, loss_bbox: 0.0000, loss: 5.9481
2021-01-11 14:37:49,194 - mmdet - INFO - Epoch [1][400/37456]   lr: 7.982e-05, eta: 2:10:30, time: 0.206, data_time: 0.004, memory: 1693, loss_cls: 0.0433, loss_bbox: 0.0000, loss: 0.0433

End:

2021-01-11 16:42:49,446 - mmdet - INFO - Epoch [1][37000/37456] lr: 1.000e-04, eta: 0:01:33, time: 0.203, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:42:59,568 - mmdet - INFO - Epoch [1][37050/37456] lr: 1.000e-04, eta: 0:01:23, time: 0.202, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:43:09,747 - mmdet - INFO - Epoch [1][37100/37456] lr: 1.000e-04, eta: 0:01:12, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:43:20,078 - mmdet - INFO - Epoch [1][37150/37456] lr: 1.000e-04, eta: 0:01:02, time: 0.207, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:43:30,408 - mmdet - INFO - Epoch [1][37200/37456] lr: 1.000e-04, eta: 0:00:52, time: 0.207, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:43:40,638 - mmdet - INFO - Epoch [1][37250/37456] lr: 1.000e-04, eta: 0:00:42, time: 0.205, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:43:50,855 - mmdet - INFO - Epoch [1][37300/37456] lr: 1.000e-04, eta: 0:00:31, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:44:01,087 - mmdet - INFO - Epoch [1][37350/37456] lr: 1.000e-04, eta: 0:00:21, time: 0.205, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:44:11,280 - mmdet - INFO - Epoch [1][37400/37456] lr: 1.000e-04, eta: 0:00:11, time: 0.204, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:44:21,508 - mmdet - INFO - Epoch [1][37450/37456] lr: 1.000e-04, eta: 0:00:01, time: 0.205, data_time: 0.004, memory: 1693, loss_cls: 0.0000, loss_bbox: 0.0000, loss: 0.0000
2021-01-11 16:44:22,741 - mmdet - INFO - Saving checkpoint at 1 epochs
2021-01-11 17:19:55,613 - mmdet - INFO - Evaluating bbox...
2021-01-11 17:19:55,622 - mmdet - ERROR - The testing results of the whole dataset is empty.
2021-01-11 17:19:56,421 - mmdet - INFO - Exp name: retinanet-CCT-SGDLR0001M9-BS1-e1.py
2021-01-11 17:19:56,421 - mmdet - INFO - Epoch(val) [1][37456]

Let me know if you need any more information; I really appreciate any help with this. Thank you.

waddington-ou-phd-1 commented 3 years ago

If I set the logging interval to 1, even the first log line reports a loss_bbox of 0.

waddington-ou-phd-1 commented 3 years ago

If I set the number of classes for the model to 21, I still have the same problem. If I reduce the learning rate to 0.00001, I still have the same problem. This seems like it must be a dataset issue, but for the life of me I can't work out what.

waddington-ou-phd-1 commented 3 years ago

I have added a trailing comma to the list of classes as suggested in #3305 and still have the problem.
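The trailing comma only changes anything when there is a single class: without it, Python parses ('opossum') as a plain string rather than a one-element tuple, which code expecting a sequence of class names will misinterpret. With 22 names the parentheses already form a tuple, so the comma makes no difference for this config.

```python
single_without_comma = ('opossum')
single_with_comma = ('opossum',)

print(type(single_without_comma).__name__)  # str
print(type(single_with_comma).__name__)     # tuple
print(len(single_without_comma))            # 7 (characters, not classes)
print(len(single_with_comma))               # 1

# With several names the parentheses already form a tuple
many = ('opossum', 'raccoon')
print(type(many).__name__, len(many))       # tuple 2
```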

waddington-ou-phd-1 commented 3 years ago

I have created a subset of the annotations file with just 1 image and 1 annotation, and I get the same error. I then used pycocotools (the COCO API) to check the annotations file, and it displays the image correctly with the bounding box in the correct location, so I don't think the problem is with the annotation files.

Annotations Subset

{
    "info": {
        "contributor": "Sara Beery",
        "version": "20200316",
        "year": 2018,
        "date_created": "2019-09-23 06:59:06.304889",
        "description": "Bounding box annotations for 63,025 images from Caltech Camera Traps, where the images only have one species label or are empty. Contains all annotations for CCT - 20, the 20-location dataset used in the ECCV18 paper \"Recognition in Terra Incognita,\" as well as additional annotations collected by MS AI for Earth."
    },
    "categories": [
        {
            "id": 6,
            "name": "bobcat"
        },
        {
            "id": 1,
            "name": "opossum"
        },
        {
            "id": 30,
            "name": "empty"
        },
        {
            "id": 9,
            "name": "coyote"
        },
        {
            "id": 3,
            "name": "raccoon"
        },
        {
            "id": 11,
            "name": "bird"
        },
        {
            "id": 8,
            "name": "dog"
        },
        {
            "id": 16,
            "name": "cat"
        },
        {
            "id": 5,
            "name": "squirrel"
        },
        {
            "id": 10,
            "name": "rabbit"
        },
        {
            "id": 7,
            "name": "skunk"
        },
        {
            "id": 14,
            "name": "lizard"
        },
        {
            "id": 99,
            "name": "rodent"
        },
        {
            "id": 21,
            "name": "badger"
        },
        {
            "id": 34,
            "name": "deer"
        },
        {
            "id": 37,
            "name": "cow"
        },
        {
            "id": 33,
            "name": "car"
        },
        {
            "id": 51,
            "name": "fox"
        },
        {
            "id": 39,
            "name": "pig"
        },
        {
            "id": 40,
            "name": "mountain_lion"
        },
        {
            "id": 66,
            "name": "bat"
        },
        {
            "id": 97,
            "name": "insect"
        }
    ],
    "annotations": [
        {
            "image_id": 5357575699102975245505110050454949101564597549751451019948565498485054494898,
            "id": 509753525353504845999810249454949101564556495799455755489757525348991009899,
            "bbox": [
                499.2,
                711.68,
                353.28000000000003,
                199.67999999999995
            ],
            "category_id": 1,
            "area": 0,
            "iscrowd": 0,
            "segmentation": []
        }
    ],
    "images": [
        {
            "file_name": "cct_images/5998cfa4-23d2-11e8-a6a3-ec086b02610b.jpg",
            "seq_id": "6f084ccc-5567-11e8-bc84-dca9047ef277",
            "height": 1494,
            "id": 5357575699102975245505110050454949101564597549751451019948565498485054494898,
            "frame_num": 1,
            "width": 2048,
            "date_captured": "2011-05-13 23:43:18",
            "rights_holder": "Justin Brown",
            "seq_num_frames": 3,
            "location": "33"
        }
    ]
}
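The categories in this file use non-contiguous IDs (1, 3, 5, ..., 97, 99) listed out of order, while the classification head expects contiguous labels 0 to 21. Per answer 3, CocoDataset builds that mapping itself. The sketch below is a hypothetical standalone illustration of such a mapping, assuming labels follow the order of the classes tuple in the config. mmdet's own ordering rule may differ between versions (it may, for instance, follow the order of the categories in the JSON), so it is worth confirming against the cat_ids and cat2label attributes of the installed CocoDataset.

```python
def build_cat2label(categories, classes):
    """Map COCO category IDs to contiguous labels 0..len(classes)-1.

    Labels follow the order of `classes`; categories whose name is not in
    `classes` are left out. Raises KeyError if a class name has no category.
    """
    name_to_id = {cat['name']: cat['id'] for cat in categories}
    return {name_to_id[name]: label for label, name in enumerate(classes)}


# A few categories from the file above, in the file's own order
categories = [{'id': 6, 'name': 'bobcat'}, {'id': 1, 'name': 'opossum'},
              {'id': 3, 'name': 'raccoon'}, {'id': 99, 'name': 'rodent'}]
classes = ('opossum', 'raccoon', 'bobcat', 'rodent')
cat2label = build_cat2label(categories, classes)
print(cat2label)  # {1: 0, 3: 1, 6: 2, 99: 3}
```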

Simple pycocotools script

import os

import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
from pycocotools.coco import COCO

images_dir = './../data/'
ann_file = './manual.json'

coco = COCO(ann_file)

# Display the COCO categories
cats = coco.loadCats(coco.getCatIds())
names = [cat['name'] for cat in cats]
print('COCO categories: \n{}\n'.format(' '.join(names)))

# Get all images containing the given category and select one at random
cat_ids = coco.getCatIds(catNms=['opossum'])
img_ids = coco.getImgIds(catIds=cat_ids)
img = coco.loadImgs(img_ids[np.random.randint(0, len(img_ids))])[0]
print(img)

# Load the image
image = io.imread(os.path.join(images_dir, img['file_name']))

# Load and display the instance annotations, including their bounding boxes
plt.imshow(image)
plt.axis('off')
ann_ids = coco.getAnnIds(imgIds=img['id'], catIds=cat_ids, iscrowd=None)
anns = coco.loadAnns(ann_ids)
coco.showAnns(anns, draw_bbox=True)
plt.show()

Images

Original: 5998cfa4-23d2-11e8-a6a3-ec086b02610b

With bbox: Figure_1

waddington-ou-phd-1 commented 3 years ago

After some testing: num_classes must be 22 for this dataset. It cannot be left unchanged at 80, otherwise an IndexError (list index out of range) is raised at validation time.
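A related cheap check, given the SSD symptom in the original post (MMDetection seeing fewer classes than expected): compare the names in classes with the category names in the annotation file. A class name that matches no category (a typo or different spelling) cannot receive any ground-truth labels. compare_class_names is a hypothetical helper working on the loaded JSON's categories list.

```python
def compare_class_names(categories, classes):
    """Return (names missing from the annotation file, names missing from classes)."""
    ann_names = {cat['name'] for cat in categories}
    missing_in_ann = [name for name in classes if name not in ann_names]
    missing_in_classes = sorted(ann_names - set(classes))
    return missing_in_ann, missing_in_classes


categories = [{'id': 1, 'name': 'opossum'}, {'id': 40, 'name': 'mountain_lion'}]
classes = ('opossum', 'mountain lion')  # deliberate spelling mismatch
print(compare_class_names(categories, classes))
# (['mountain lion'], ['mountain_lion'])
```

Both lists should be empty, and len(classes) should equal num_classes in the model config.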

If I reduce all of the confidence thresholds in the model config file to 0.0, I no longer get the "The testing results of the whole dataset is empty" error, but obviously this is not a good solution. This suggests that the model only produces very low-confidence predictions, which is fair enough, but this is contradicted by the next point.
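As a rough illustration of why the results can come back empty: at test time, detections whose score does not exceed score_thr (0.05 in the test_cfg above) are discarded, and with sigmoid-based classification each score is the sigmoid of a logit. A model that has only ever been pushed towards background can end up with every logit strongly negative, so every score falls under the threshold. This is plain Python with made-up logits, not mmdet code; the helper names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def keep_detections(logits, score_thr):
    """Return the scores that survive a strict score > score_thr filter."""
    scores = [sigmoid(z) for z in logits]
    return [s for s in scores if s > score_thr]


# Made-up logits from a model trained only towards background
logits = [-8.0, -6.5, -9.2]
print(len(keep_detections(logits, 0.05)))  # 0: nothing survives
print(len(keep_detections(logits, 0.0)))   # 3: every detection is kept
```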

Even so, loss_bbox is still always 0. If the model were producing low-confidence predictions, and they are not suppressed when the thresholds are set to 0, I would expect the loss to be very high, not 0. I would expect a nonzero loss regardless of the predictions made; I can't imagine a situation in which a loss of exactly 0 would be genuine.
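For what it's worth, a box regression loss of exactly 0 can be genuine: the regression loss is only computed for anchors assigned to a ground-truth box (positives). If no ground-truth boxes reach the model, there are no positives, every regression weight is 0, and the loss is 0 whatever the model predicts, while the classification loss trains every anchor towards background (which would match loss_cls falling towards 0 in the log). Below is a simplified plain-Python illustration of such a masked L1 loss; it is not mmdet's implementation, and masked_l1_loss is a hypothetical name.

```python
def masked_l1_loss(pred_deltas, target_deltas, is_positive):
    """L1 loss summed over box coordinates and averaged over positive anchors.

    pred_deltas / target_deltas: one 4-tuple of box deltas per anchor.
    is_positive: one bool per anchor, True if the anchor was assigned to a
    ground-truth box. Negative anchors get weight 0, so they add nothing.
    """
    total = 0.0
    num_pos = 0
    for pred, target, pos in zip(pred_deltas, target_deltas, is_positive):
        weight = 1.0 if pos else 0.0
        total += weight * sum(abs(p - t) for p, t in zip(pred, target))
        num_pos += pos
    # Normalise by the number of positives, clamped to avoid dividing by 0
    return total / max(num_pos, 1)


preds = [(0.1, -0.2, 0.3, 0.0), (0.5, 0.5, -0.5, 0.2)]
targets = [(0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0)]
# No anchor matched any ground-truth box, e.g. because none were loaded
print(masked_l1_loss(preds, targets, [False, False]))  # 0.0
```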

waddington-ou-phd-1 commented 3 years ago

This problem seems to be caused by #4436