leondgarse / keras_cv_attention_models

Keras beit,caformer,CMT,CoAtNet,convnext,davit,dino,efficientdet,edgenext,efficientformer,efficientnet,eva,fasternet,fastervit,fastvit,flexivit,gcvit,ghostnet,gpvit,hornet,hiera,iformer,inceptionnext,lcnet,levit,maxvit,mobilevit,moganet,nat,nfnets,pvt,swin,tinynet,tinyvit,uniformer,volo,vanillanet,yolor,yolov7,yolov8,yolox,gpt2,llama2, alias kecam
MIT License

Evaluation of YOLOX with eval_script #50

Closed: tylertroy closed this issue 2 years ago

tylertroy commented 2 years ago

This issue has two parts. For the first, I propose a solution, but I include it here in case the proposed solution affects part 2.

  1. Error in eval script when using anchor free mode.
  2. Bad/erroneous results from evaluation

Part 1

When running eval_script.py (which calls run_coco_evaluation in eval_func.py) with the --use_anchor_free_mode flag, the following error is raised:

>>>> Using anchor_free_mode decode_predictions: {'aspect_ratios': [1], 'num_scales': 1, 'anchor_scale': 1, 'grid_zero_start': True}
Traceback (most recent call last):
  File "./eval_script.py", line 154, in <module>
    run_coco_evaluation(
  File "/home/lookdeep/gits/keras_cv_attention_models/keras_cv_attention_models/coco/eval_func.py", line 216, in run_coco_evaluation
    pred_decoder = DecodePredictions(input_shape, pyramid_levels, **ANCHORS, use_object_scores=use_anchor_free_mode)
TypeError: __init__() got an unexpected keyword argument 'aspect_ratios'

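The issue lies in DecodePredictions (keras_cv_attention_models/blob/main/keras_cv_attention_models/coco/eval_func.py#L27), whose __init__ does not accept the anchor-free keyword arguments (aspect_ratios, num_scales, grid_zero_start) that run_coco_evaluation passes in via **ANCHORS.

It can be fixed with: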

diff --git a/keras_cv_attention_models/coco/eval_func.py b/keras_cv_attention_models/coco/eval_func.py
index faafb7a..414e3ee 100644
--- a/keras_cv_attention_models/coco/eval_func.py
+++ b/keras_cv_attention_models/coco/eval_func.py
@@ -24,7 +24,7 @@ class DecodePredictions:
     >>> # bboxes = array([[0.433231  , 0.54432285, 0.8778939 , 0.8187578 ]], dtype=float32), labels = array([17]), scores = array([0.85373735], dtype=float32)
     """

-    def __init__(self, input_shape=512, pyramid_levels=[3, 7], anchor_scale=4, use_anchor_free_mode=False, use_object_scores="auto"):
+    def __init__(self, input_shape=512, pyramid_levels=[3, 7], anchor_scale=4, use_anchor_free_mode=False, use_object_scores="auto", aspect_ratios=[1], num_scales=1, grid_zero_start=True):
         self.pyramid_levels = list(range(min(pyramid_levels), max(pyramid_levels) + 1))
         self.use_object_scores = use_anchor_free_mode if use_object_scores == "auto" else use_object_scores
         if use_anchor_free_mode:
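To see why the call in the traceback fails, here is a minimal reproduction using two hypothetical stand-in classes (DecoderWithoutAnchorKwargs and DecoderWithAnchorKwargs are illustrative names, not library classes) that copy only the __init__ signatures involved. Unpacking the anchor-free settings printed in the log with **ANCHORS raises the same TypeError against the narrower signature, and succeeds once aspect_ratios, num_scales and grid_zero_start are accepted. Tuples replace the list defaults only to avoid shared mutable defaults in the sketch.

```python
# Hypothetical stand-ins mirroring only the __init__ signatures from the diff;
# they are NOT the library's DecodePredictions and do no decoding.
class DecoderWithoutAnchorKwargs:
    def __init__(self, input_shape=512, pyramid_levels=(3, 7), anchor_scale=4,
                 use_anchor_free_mode=False, use_object_scores="auto"):
        self.anchor_scale = anchor_scale


class DecoderWithAnchorKwargs:
    def __init__(self, input_shape=512, pyramid_levels=(3, 7), anchor_scale=4,
                 use_anchor_free_mode=False, use_object_scores="auto",
                 aspect_ratios=(1,), num_scales=1, grid_zero_start=True):
        self.anchor_scale = anchor_scale
        self.aspect_ratios = list(aspect_ratios)
        self.num_scales = num_scales
        self.grid_zero_start = grid_zero_start


# Anchor-free settings exactly as printed in the log.
ANCHORS = {"aspect_ratios": [1], "num_scales": 1, "anchor_scale": 1, "grid_zero_start": True}

try:
    # Same call shape as in the traceback: positional args, **ANCHORS, then a keyword.
    DecoderWithoutAnchorKwargs(640, [3, 5], **ANCHORS, use_object_scores=True)
except TypeError as err:
    print("narrow signature:", err)  # ... unexpected keyword argument 'aspect_ratios'

decoder = DecoderWithAnchorKwargs(640, [3, 5], **ANCHORS, use_object_scores=True)
print("wide signature:", decoder.anchor_scale, decoder.aspect_ratios, decoder.num_scales)
```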

Part 2

When evaluating yolox_s_coco.h5, I get poor results that don't match what I see when running inference, e.g.:

CUDA_VISIBLE_DEVICES='0' python ./eval_script.py -d coco -m $HOME/.keras/models/yolox_s_coco.h5 --use_anchor_free_mode
... [stderr excluded]
>>>> COCO evaluation: coco/2017 - anchors: {'pyramid_levels': [3, 5], 'use_anchor_free_mode': True}
>>>> Using input_shape (640, 640) for Keras model.
>>>> rescale_mode: torch
>>>> Using anchor_free_mode decode_predictions: {'aspect_ratios': [1], 'num_scales': 1, 'anchor_scale': 1, 'grid_zero_start': True}
... [stderr excluded]
loading annotations into memory...
Done (t=0.82s)
creating index...
index created!
Loading and preparing results...
Converting ndarray to lists...
(195236, 7)
0/195236
DONE (t=1.39s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=23.52s).
Accumulating evaluation results...
DONE (t=3.52s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.001
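The "(195236, 7)" line in the log is consistent with the N×7 ndarray layout that pycocotools' COCO.loadRes accepts: one row per detection as [image_id, x, y, w, h, score, category_id], with boxes in absolute pixels and category_id being the original (non-contiguous, 1 to 90) COCO id. The sketch below packs detections into that layout; the function name to_coco_rows and the assumption that boxes arrive as pixel [top, left, bottom, right] are illustrative only and not necessarily what this repository's decoder emits.

```python
import numpy as np


def to_coco_rows(image_id, boxes_tlbr, scores, category_ids):
    """Pack detections into [image_id, x, y, w, h, score, category_id] rows.

    Hypothetical helper: assumes boxes_tlbr are absolute pixel [top, left, bottom, right].
    """
    boxes = np.asarray(boxes_tlbr, dtype=np.float64).reshape(-1, 4)
    top, left, bottom, right = boxes.T
    n = boxes.shape[0]
    return np.stack(
        [
            np.full(n, image_id, dtype=np.float64),
            left,            # COCO x is the left edge
            top,             # COCO y is the top edge
            right - left,    # width
            bottom - top,    # height
            np.asarray(scores, dtype=np.float64),
            np.asarray(category_ids, dtype=np.float64),
        ],
        axis=1,
    )


rows = to_coco_rows(139, [[10.0, 20.0, 110.0, 70.0]], [0.9], [18])
print(rows.shape)  # (1, 7)
```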
leondgarse commented 2 years ago

This is a result from testing YOLOR_CSP, which is also lower than the officially reported numbers:

CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m yolor.YOLOR_CSP -d coco --batch_size 2
#  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.444
#  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.610
#  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.483
#  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.301
#  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.490
#  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.566
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.327
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.543
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.586
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.422
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.628
#  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.711
leondgarse commented 2 years ago

With the above two commits, YOLOX evaluation now works correctly. It turns out YOLOX uses BGR input format. The YOLOXS results are also updated in coco#evaluation.

CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m yolox.YOLOXTiny --nms_method hard --nms_iou_or_sigma 0.65 --use_bgr_input --use_anchor_free_mode
 # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.329
 # Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.504
 # Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.349
 # Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.138
 # Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.360
 # Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.499
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.287
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.458
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.486
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.230
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.549
 # Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.692
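For reference, switching an image between RGB and BGR channel order is just a reversal of the channel axis. The sketch below assumes channels-last (HWC or NHWC) NumPy arrays; how --use_bgr_input is actually implemented in the repository is not shown in this thread, and rgb_to_bgr is a hypothetical helper name.

```python
import numpy as np


def rgb_to_bgr(image):
    """Reverse the last (channel) axis: RGB <-> BGR. Works for HWC and NHWC arrays."""
    return image[..., ::-1]


rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255          # pure red in RGB order
bgr = rgb_to_bgr(rgb)      # red now sits in the last channel
print(bgr[0, 0])
```

Feeding an RGB image to a model trained on BGR swaps the red and blue channels for every input, which is the kind of silent preprocessing mismatch that can collapse evaluation scores while the pipeline itself still runs.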

But YOLOR is still not right. I think it may be related to its letterbox function; still investigating.
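As background on why letterboxing matters for evaluation: a letterbox resize scales the image so it fits the target size while keeping its aspect ratio, pads the remainder, and predicted boxes must then be mapped back through the same padding and scale before scoring against COCO annotations. The sketch below assumes centered padding, a gray fill of 114 (a value commonly used in YOLO-family code), a square target, and pixel [top, left, bottom, right] boxes; the function names are hypothetical, the actual image resizing step is omitted, and this is not YOLOR's or this repository's letterbox implementation.

```python
import numpy as np


def letterbox_params(src_h, src_w, target=640):
    """Scale so the image fits inside target x target, then center the padding."""
    scale = min(target / src_h, target / src_w)
    new_h, new_w = int(round(src_h * scale)), int(round(src_w * scale))
    pad_top = (target - new_h) // 2
    pad_left = (target - new_w) // 2
    return scale, pad_top, pad_left


def pad_to_square(resized, target, pad_top, pad_left, fill=114):
    """Place an already-resized image onto a target x target canvas filled with `fill`."""
    h, w = resized.shape[:2]
    out = np.full((target, target) + resized.shape[2:], fill, dtype=resized.dtype)
    out[pad_top:pad_top + h, pad_left:pad_left + w] = resized
    return out


def letterbox_boxes_to_source(boxes, scale, pad_top, pad_left):
    """Undo padding then scaling for pixel [top, left, bottom, right] boxes."""
    boxes = np.asarray(boxes, dtype=np.float64)
    offset = np.array([pad_top, pad_left, pad_top, pad_left], dtype=np.float64)
    return (boxes - offset) / scale


scale, pad_top, pad_left = letterbox_params(480, 640, target=640)
print(scale, pad_top, pad_left)  # 1.0 80 0
canvas = pad_to_square(np.zeros((480, 640, 3), dtype=np.uint8), 640, pad_top, pad_left)
print(letterbox_boxes_to_source([[80, 0, 560, 640]], scale, pad_top, pad_left))
```

If the padding or scale used at inference differs from the one used to map boxes back, every box is shifted or stretched slightly, which tends to lower AP at strict IoU thresholds more than at IoU=0.50.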

CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m yolor.YOLOR_CSP --nms_method hard --nms_iou_or_sigma 0.65 --use_yolor_anchors_mode
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.488
# Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.674
# Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.530
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.324
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.539
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.627
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.365
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.592
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.634
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.447
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.684
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.779
leondgarse commented 2 years ago

YOLOR evaluation is finally fixed:

CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m yolor.YOLOR_CSP --use_yolor_anchors_mode --nms_method hard --nms_iou_or_sigma 0.65 \
--nms_max_output_size 300 --nms_topk -1 --letterbox_pad 64 --input_shape 704
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.500
# Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.686
# Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.544
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.340
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.551
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.643
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.380
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.627
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.683
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.529
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.735
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.817
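The commands above use --nms_method hard --nms_iou_or_sigma 0.65 --nms_max_output_size 300. Assuming "hard" means standard greedy NMS (as opposed to soft-NMS, where, as the flag name suggests, the same value would act as a Gaussian sigma), the sketch below shows what an IoU threshold of 0.65 and an output cap of 300 do. It is a single-class NumPy illustration with hypothetical function names, not the repository's NMS code, and it ignores --nms_topk and per-class handling.

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one [top, left, bottom, right] box and an (N, 4) array of boxes."""
    top = np.maximum(box[0], boxes[:, 0])
    left = np.maximum(box[1], boxes[:, 1])
    bottom = np.minimum(box[2], boxes[:, 2])
    right = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(bottom - top, 0, None) * np.clip(right - left, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / np.maximum(area + areas - inter, 1e-9)


def hard_nms(boxes, scores, iou_threshold=0.65, max_output_size=300):
    """Greedy hard NMS: keep the highest-scoring box, drop boxes overlapping it above the threshold, repeat."""
    boxes = np.asarray(boxes, dtype=np.float64)
    order = np.argsort(-np.asarray(scores, dtype=np.float64))
    keep = []
    while order.size > 0 and len(keep) < max_output_size:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        overlaps = iou_one_to_many(boxes[best], boxes[rest])
        order = rest[overlaps <= iou_threshold]  # suppressed boxes are removed outright
    return keep


boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(hard_nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 with IoU 0.81 > 0.65
```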

More comparison results can be found in coco#evaluation.