zylo117 / Yet-Another-EfficientDet-Pytorch

The pytorch re-implement of the official efficientdet with SOTA performance in real time and pretrained weights.
GNU Lesser General Public License v3.0

Why do you get the result more slowly than in the official paper? #77

Open AlexeyAB opened 4 years ago

AlexeyAB commented 4 years ago

@zylo117 Hi! Nice work!

  1. Can you explain please, why do you get the result more slowly than in the official paper?

  2. Does the official code https://github.com/google/automl/tree/master/efficientdet not reproduce the results from the article?

According to the official paper, the official D0 runs at 41.67 FPS on a Titan V, which corresponds to roughly 33.33 FPS on an RTX 2080 Ti, while you get only 2.09 FPS with the official D0 on an RTX 2080 Ti.

| Model | Time | FPS | Ratio | mAP 0.5:0.95 |
| --- | --- | --- | --- | --- |
| Official D0 (tf postprocess) | 0.713s | 1.40 | 1X | 33.8 |
| Official D0 (numpy postprocess) | 0.477s | 2.09 | 1.49X | 33.8 |
| Yet-Another-EfficientDet-D0 | 0.028s | 36.20 (best) | 25.86X | 32.6 |
| Official D0 from paper Titan V https://arxiv.org/abs/1911.09070v4 | 0.024s (Titan V) | 41.67 | 29.76X | 33.8 |
| Official D0 from paper (calculated for RTX 2080 Ti) | 0.030s (RTX2080 Ti) | ~33.33 | 23.80X | 33.8 (best) |
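
The FPS and Ratio figures above follow from simple arithmetic. A minimal sketch, assuming FPS is the reciprocal of the per-image time, the ratio is taken against the 1.40 FPS tf-postprocess row, and the Titan V is about 1.25x faster than the RTX 2080 Ti (a rough scaling from the Tensor Core TFLOPS figures quoted in this thread, not a measurement):

```python
def fps_from_seconds(seconds_per_image: float) -> float:
    """Frames per second is the reciprocal of per-image latency."""
    return 1.0 / seconds_per_image


# Latency reported in the paper for D0 on a Titan V.
titan_v_fps = fps_from_seconds(0.024)

# Assumed Titan V : RTX 2080 Ti speed ratio (134 vs. 107 TFLOPS, rounded to 1.25).
TITAN_V_OVER_2080TI = 1.25
rtx_2080ti_fps = titan_v_fps / TITAN_V_OVER_2080TI

# Speedup relative to the slowest row (official D0 with tf postprocess).
baseline_fps = 1.40
ratio = titan_v_fps / baseline_fps

print(f"{titan_v_fps:.2f} FPS on Titan V")
print(f"{rtx_2080ti_fps:.2f} FPS estimated on RTX 2080 Ti")
print(f"{ratio:.2f}x over the baseline")
```

The linear TFLOPS scaling is only a back-of-the-envelope estimate; real throughput also depends on memory bandwidth, kernel selection, and batch size.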

https://arxiv.org/abs/1911.09070v4


zylo117 commented 4 years ago

I have no idea why the tf version is so slow, but it's true: it is slow. And it's not just me; the community has also found that the tf version is much slower.

check this out: https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/issues/4#issuecomment-611337716

It's 16ms in the paper, but it actually runs at 200-300ms, or 700ms if it's a complex image.

My best bet is that there is a bug or environment issue in tf effdet.

AlexeyAB commented 4 years ago

@mingxingtan Hi, please could you explain the cause of this problem?

mingxingtan commented 4 years ago

@zylo117 Great to see you make the pytorch version much faster! Congratulations!

@AlexeyAB Our paper mainly focuses on network architecture, so the reported latency is for the network architecture only, excluding post-processing / NMS. When open-sourcing it, I was lazy and just copied the same Python-based post-processing from TF-RetinaNet, which is slow since it runs purely in Python on the CPU.

See https://github.com/google/automl/issues/97, I will try to prioritize speeding up post-processing. (+ @fsx950223, who is the main contributor for this part).
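
To see how much of the end-to-end latency comes from the network versus post-processing, the two stages can be timed separately. A minimal sketch using only the standard library; `fake_network` and `fake_postprocess` are hypothetical stand-ins, not functions from either repository:

```python
import time
from typing import Any, Callable, Dict, List, Tuple


def time_stages(
    stages: List[Tuple[str, Callable[[Any], Any]]],
    sample: Any,
    repeats: int = 10,
) -> Dict[str, float]:
    """Feed `sample` through the stages in order; return mean ms per stage."""
    totals = {name: 0.0 for name, _ in stages}
    for _ in range(repeats):
        data = sample
        for name, stage in stages:
            start = time.perf_counter()
            data = stage(data)
            totals[name] += time.perf_counter() - start
    report = {name: 1000.0 * t / repeats for name, t in totals.items()}
    report["total"] = sum(report.values())
    return report


# Hypothetical stand-ins for the real model and its post-processing.
def fake_network(image):
    return [(i * 0.1, i) for i in range(100)]


def fake_postprocess(raw):
    return sorted(raw, reverse=True)[:10]


report = time_stages(
    [("inference", fake_network), ("postprocess", fake_postprocess)],
    sample=None,
)
print(report)
```

When the network runs on a GPU, the device has to be synchronized before each timestamp is read (in PyTorch, `torch.cuda.synchronize()`), otherwise asynchronous kernels make the inference stage look faster than it is.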

AlexeyAB commented 4 years ago

@mingxingtan Thanks!

Did you develop EfficientDet primarily for GPU or TPU? I ask because it uses grouped convolutional layers and SE blocks, which are slow on GPU/TPU-edge.


So the 32.6 AP / 36.20 FPS of the fastest EfficientDet implementation is consistent with the 33.8 AP / 33.33 FPS (RTX 2080 Ti) from the article, while the official implementation is slower at 33.8 AP / 2.09 FPS, because the FPS in the paper was measured for network inference only, excluding post-processing and NMS.

glenn-jocher commented 4 years ago

@AlexeyAB @mingxingtan I think the large differences in timing are due to NMS. Python NMS can often be >10X slower than C/CUDA NMS, so if Python NMS is used it can easily dominate the total detection time. PyTorch updated its CPU/GPU NMS code last year, and NMS operations are now very fast, so this is probably why this repo is showing faster speeds.
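
For reference, a minimal pure-Python sketch of greedy NMS, the kind of interpreter-bound loop that becomes slow with many boxes. This is an illustration, not the post-processing code of either repository; boxes are assumed to be `(x1, y1, x2, y2)` tuples:

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the best box, drop boxes overlapping it, repeat."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # Every remaining box is compared against the kept one in Python.
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
kept = nms(boxes, scores)
print(kept)  # indices of the boxes that survive
```

In the worst case the loop performs a number of IoU checks quadratic in the box count. A compiled implementation such as `torchvision.ops.nms(boxes, scores, iou_threshold)` runs the same procedure in C++/CUDA.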

Over on https://github.com/ultralytics/yolov3, the average speed for each image across the 5000 coco test images on a V100 using yolov3-spp (at 43mAP) is:

Speed: 11.4/2.1/13.5 ms inference/NMS/total per 608x608 image at batch-size 1

AlexeyAB commented 4 years ago

@glenn-jocher

In the C Darknet implementation https://github.com/AlexeyAB/darknet NMS ~ 0.0ms, but it depends on number of detected bboxes.

On RTX 2070: https://github.com/AlexeyAB/darknet/issues/4497#issuecomment-564168707

Speed: 17.2/0.3/0.4/0.0/17.9 ms - inference/[yolo]/get_network_boxes()/NMS/total per 416x416 image at batch-size 1

glenn-jocher commented 4 years ago

@AlexeyAB ah this is very fast too! In the past when I had python NMS it might take up to 10-30 ms or more for NMS per image, so this reduction down to 2 ms is a great speedup for me. The time I posted is for testing (i.e. very low --conf 0.001), which will generate many hundreds/thousands of boxes needing NMS. For regular inference (i.e. --conf 0.50) the NMS cost should be closer to near zero as your number shows.
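
The effect of the confidence threshold on NMS cost can be sketched as follows. The scores are made-up illustrative values, and the pair count is the worst case for a greedy NMS that compares every candidate against every other:

```python
def candidates_for_nms(scores, conf_threshold):
    """Indices of detections that pass the confidence filter before NMS."""
    return [i for i, s in enumerate(scores) if s >= conf_threshold]


def worst_case_iou_checks(n):
    """Upper bound on pairwise IoU comparisons for n candidate boxes."""
    return n * (n - 1) // 2


# Hypothetical raw scores for one image (not taken from any real model).
scores = [0.95, 0.80, 0.60, 0.30, 0.05, 0.02, 0.01, 0.004, 0.002, 0.0005]

counts = {}
for conf in (0.001, 0.50):
    kept = candidates_for_nms(scores, conf)
    counts[conf] = (len(kept), worst_case_iou_checks(len(kept)))
    print(f"conf={conf}: {counts[conf][0]} boxes, up to {counts[conf][1]} IoU checks")
```

Because the comparison count grows quadratically, going from a few boxes to thousands (as with `--conf 0.001` during testing) increases NMS work by orders of magnitude.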

Maybe this makes it especially odd that the TF efficientdet post-detection time is slow for regular inference. Part of the slowdown is surely due to the grouped convolutions, but I assume that cost is already baked into Table 2.

fsx950223 commented 4 years ago

Could you explain why your official result is much slower than my cpu?

YashasSamaga commented 4 years ago

I think it's just cuDNN's heuristics making a mistake in the algorithm selection. You need to override cuDNN's heuristics to use the "unfused" IMPLICIT_GEMM algorithm for depthwise convolution. The other algorithms are much slower than the CPU for depthwise convolution.

I have seen models like MobileNet, EfficientNet YOLO, etc. become 10x faster after forcing cuDNN to use IMPLICIT_GEMM for depthwise convolution.

Since this model is using depthwise convolution, I speculate that this might be causing the problems.

bloom1123 commented 2 years ago

@zylo117 Hi! Nice work!

1. Can you explain please, why do you get the result more slowly than in the official paper?

2. Does the official code https://github.com/google/automl/tree/master/efficientdet not reproduce the results from the article?

According to the official paper, the official D0 runs at 41.67 FPS on a Titan V, which corresponds to roughly 33.33 FPS on an RTX 2080 Ti, while you get only 2.09 FPS with the official D0 on an RTX 2080 Ti.

* Titan V - 134 TFLOPS Tensor Cores (1.25x)
* RTX 2080 Ti - 107 TFLOPS Tensor Cores (1x)

| Model | Time | FPS | Ratio | mAP 0.5:0.95 |
| --- | --- | --- | --- | --- |
| Official D0 (tf postprocess) | 0.713s | 1.40 | 1X | 33.8 |
| Official D0 (numpy postprocess) | 0.477s | 2.09 | 1.49X | 33.8 |
| Yet-Another-EfficientDet-D0 | 0.028s | 36.20 (best) | 25.86X | 32.6 |
| Official D0 from paper Titan V https://arxiv.org/abs/1911.09070v4 | 0.024s (Titan V) | 41.67 | 29.76X | 33.8 |
| Official D0 from paper (calculated for RTX 2080 Ti) | 0.030s (RTX2080 Ti) | ~33.33 | 23.80X | 33.8 (best) |

https://arxiv.org/abs/1911.09070v4


Did you use Yet-Another-EfficientDet-D0 and get the 36.2 FPS using the efficientdet_test code?