facebookresearch / detr

End-to-End Object Detection with Transformers
Apache License 2.0

Quick way to get per-class evaluation results? #106

Closed: linzzzzzz closed this issue 4 years ago

linzzzzzz commented 4 years ago

Thank you for the great work!

I'm curious whether there is a quick way to see per-class evaluation results on val5k (or the evaluation results for a specific class).

Here is how I'm currently doing it:

I added self.coco_eval[iou_type].params.catIds = [72] in coco_eval.py, after line 31, to get evaluation results for class 72 only.

https://github.com/facebookresearch/detr/blob/10a2c759454930813aeac7af5e779f835dcb75f5/datasets/coco_eval.py#L31

I think this should work, but I'm not sure, since coco_eval.py looks quite complicated to me. Could you help me confirm whether what I'm doing gives accurate results? Thanks!
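A note on the approach in the question: with COCOeval's default useCats=1, detections are matched to ground truth separately for each category (the per-image maxDets cap is also applied per category), so setting params.catIds = [72] should reproduce the numbers for category 72 ('tv') from a full evaluation, and the printed summary AP is then just that one class's AP. If you instead read per-class values from a full run, keep in mind that the category axis of COCOeval's precision and recall arrays follows sorted(cocoGt.getCatIds()), not the raw ids. A minimal sketch of that lookup, assuming the standard 80 COCO detection categories; VALID_COCO_IDS and category_axis_index are hypothetical helper names, not part of DETR or pycocotools:

```python
# Sketch: locate a COCO category id on the category axis of COCOeval's
# precision/recall arrays. Assumption: the standard 80 COCO detection
# categories, which COCOeval stores as sorted(cocoGt.getCatIds()).

# COCO ids run from 1 to 90; these ten ids have no detection category
# (they are the 'N/A' slots of the CLASSES list used in the answer).
UNUSED_COCO_IDS = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}

# The 80 valid ids in ascending order, i.e. the order of the category axis.
VALID_COCO_IDS = [i for i in range(1, 91) if i not in UNUSED_COCO_IDS]


def category_axis_index(cat_id: int) -> int:
    """Return k such that precision[:, :, k, :, :] belongs to COCO id cat_id.

    Raises ValueError (from list.index) if cat_id is not a detection category.
    """
    return VALID_COCO_IDS.index(cat_id)


if __name__ == "__main__":
    # Category 72 ('tv') has 62 valid ids below it, so it sits at index 62.
    print(category_axis_index(72))  # 62
```

With this, the 'tv' AP50 from a full run would be precision[0, :, 62, 0, -1].mean() under the layout described in the answer.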

szagoruyko commented 4 years ago

The easiest way is to pass the --output_dir argument when running evaluation, for example:

python main.py --batch_size 2 --no_aux_loss --eval \
  --resume https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
  --coco_path /path/to/coco --output_dir .

This saves an eval.pth file in the given directory, which you can load to check per-class AP50 with this code:

import torch
import pandas as pd

def load_eval(eval_path):
    # eval.pth holds COCOeval.eval (numpy arrays plus a pycocotools Params object),
    # so it needs full unpickling: weights_only=False is required on PyTorch >= 2.6
    # (drop the argument on PyTorch < 1.13, which does not accept it)
    data = torch.load(eval_path, weights_only=False)
    # precision has shape [n_iou, n_recall, n_cat, n_area, n_maxdet] = [10, 101, 80, 4, 3]
    precision = data['precision']
    CLASSES = [
        'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    # drop unused COCO ids so the names line up with the 80 entries of the category axis
    CLASSES = [c for c in CLASSES if c != 'N/A']
    # IoU=0.50 (index 0), all classes, area range 'all' (index 0), maxDets=100 (index -1);
    # averaging over the 101 recall thresholds gives AP50 per class
    area = 0
    ap50 = precision[0, :, :, area, -1].mean(0) * 100
    return pd.DataFrame.from_dict(dict(zip(CLASSES, ap50)), orient='index', columns=['AP50'])

with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    df = load_eval('eval.pth')  # written by --output_dir . in the command above
    print(df)  # display(df) also works inside Jupyter
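Two caveats about the snippet above: COCOeval fills precision with -1 for any category that has no ground truth (this matters on subsets or custom datasets), and taking only IoU index 0 gives AP50 rather than the headline COCO AP. A minimal NumPy sketch, assuming COCOeval's default parameter layout, that skips -1 entries the way COCOeval.summarize does (it averages s[s > -1]) and can also average over all ten IoU thresholds for AP@[.50:.95]; per_class_ap is a hypothetical helper name, not part of DETR or pycocotools:

```python
import numpy as np

# Assumed layout (COCOeval defaults for bbox):
#   precision[T, R, K, A, M] with T=10 IoU thresholds (0.50:0.05:0.95),
#   R=101 recall thresholds, K categories, A=4 area ranges
#   (all, small, medium, large) and M=3 maxDets (1, 10, 100).
# Entries are -1 where a category has no ground truth.


def per_class_ap(precision, iou_index=None, area=0, max_det=-1):
    """Return AP in percent for each category (NaN if it has no ground truth).

    iou_index=None averages over all IoU thresholds (AP@[.50:.95]);
    iou_index=0 gives AP50 and iou_index=5 gives AP75.
    """
    p = precision[:, :, :, area, max_det]  # -> [T, R, K]
    if iou_index is not None:
        p = p[[iou_index]]  # keep a length-1 IoU axis
    ap = np.full(p.shape[2], np.nan)
    for k in range(p.shape[2]):
        cat = p[:, :, k]
        valid = cat[cat > -1]  # ignore -1 entries, as COCOeval.summarize does
        if valid.size:
            ap[k] = valid.mean() * 100
    return ap


# Synthetic check: 2 categories, the second one has no ground truth.
prec = np.full((10, 101, 2, 4, 3), 0.5)
prec[:, :, 1] = -1
prec[0, :, 0, 0, -1] = 0.75
print(per_class_ap(prec, iou_index=0))  # AP50 per class
print(per_class_ap(prec))               # AP@[.50:.95] per class
```

Returning NaN for categories without ground truth keeps them from showing up as a misleading negative AP in the DataFrame.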

Hope this helps.

fmassa commented 4 years ago

Following @szagoruyko's answer, I'm closing this issue, but let us know if you have further questions.