@glenn-jocher what do you think about this? Am I making a mistake, or does the difference really exist?
The ground truth values are all in relative format, so you have to multiply them by the grid size of that YOLO layer. I'm not familiar with the yolo_layer code you referred to, but I believe it's processing predictions from the network, so maybe it's not a proper comparison?
@ydixon although the ground truth values are all in relative format, I don't think xy and wh are relative to the same thing. I think wh are relative to the input (416/608), not to the grid. If training from scratch this may not matter, but if using the pretrained weight file it will cause problems. It comes down to how the anchors are used: relative to the grid or to the input.
@muye5 the only difference I'm aware of in how they are handled is that the xy loss is computed post-sigmoid and the wh loss is not, so the wh losses are always larger, which has bothered me a bit, but that's another topic.
In terms of the scales, they are both in grid space, i.e. 0-12.
You could try inference with your proposed changes to see if it improves anything, but inference has always been extremely close to darknet. You can see recent comparisons in issue #51 for example. I don't see how there could be significant differences (such as the type you are thinking of) in the inference architecture with results as similar as #51 shows.
@glenn-jocher I have run prediction on COCO test-dev2017 with the weight file in this way and got AP(IoU=0.5) = 56.1, so I think this should be right. If you read the darknet code posted above, it may make the problem clearer: think about what value t_w takes if the anchor width is 300 but exp(t_w) * w_anchor has to stay below the grid size; if you scale to 608 instead, the value range of t_w is different, and closer to that of a sigmoid output.
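To make the point above concrete, here is a small numeric sketch (the anchor and box widths are illustrative, not taken from any cfg): the same ground-truth box gives a very different t_w depending on whether wh lives in input pixels or grid units while the anchor stays in input pixels.

```python
import math

# Illustrative numbers only (not from any cfg): a ground-truth box of width 250 px on a
# 13x13 (stride-32) YOLO layer, matched against an anchor of width 300 px.
anchor_w_px = 300
box_w_px    = 250
nG, stride  = 13, 32

# If wh and the anchor are both in input pixels (darknet convention), t_w is small:
t_w_input = math.log(box_w_px / anchor_w_px)                 # ~ -0.18

# If the target is converted to grid units but the anchor is left in input pixels,
# exp(t_w) * 300 has to stay below the grid size, forcing a large negative t_w:
t_w_mixed = math.log((box_w_px / stride) / anchor_w_px)      # ~ -3.65

print(t_w_input, t_w_mixed)
```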
Results on COCO test-dev2017 with the weight file provided by the author:
overall performance
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.322
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.561
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.337
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.174
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.344
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.413
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.275
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.413
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.429
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.258
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.451
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
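For reference, a minimal pycocotools sketch that produces a table in this same format. Note that test-dev2017 has no public annotations (the numbers above come from the evaluation server), so locally you can only score a split with ground truth, such as val2017 or the author's 5k split. The file paths below are placeholders.

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Placeholder paths: a ground-truth annotation file and a detections file in COCO
# results-json format (image_id, category_id, bbox, score).
coco_gt = COCO('annotations/instances_val2017.json')
coco_dt = coco_gt.loadRes('detections.json')

coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()   # prints the AP/AR summary in the format shown above
```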
@muye5 thanks for running COCO mAP on test-dev2017. Is this from test results produced by this repo? When you say the weight file provided by the 'author', do you mean yolov3.weights provided by Joseph Redmon (the darknet author)?
56.1 is very good. What resolution is this for? The official paper results are 55.3@416 and 57.9@608 (https://pjreddie.com/media/files/papers/YOLOv3.pdf).
I'm open to your idea, but I'm afraid I'm not understanding the changes you are proposing. Could you submit a PR with your proposed changes so we can get an exact idea?
@muye5 wait, I just realized the lines you are quoting are only used to build targets during training. They are not used during inference (i.e. testing), so they have no impact on mAP when using pretrained yolov3.weights.
https://github.com/ultralytics/yolov3/blob/646a5737401fe44f1ee1c679707b3a3b09995953/utils/utils.py#L227
The only lines used to convert the model outputs into boxes during inference are here: https://github.com/ultralytics/yolov3/blob/646a5737401fe44f1ee1c679707b3a3b09995953/models.py#L252-L260
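For readers following along, here is a hedged paraphrase (not a verbatim copy of the repo code) of what that inference decode does: xy are passed through a sigmoid and offset by the cell index, wh use exp() times anchors that have already been divided by the stride, and everything is multiplied back up by the stride at the end.

```python
import torch

def decode(p, anchors_px, stride, nG):
    """Sketch of a YOLO-layer inference decode; p has shape (batch, nA, nG, nG, 4+)."""
    grid = torch.arange(nG, dtype=p.dtype)
    grid_y, grid_x = torch.meshgrid(grid, grid, indexing='ij')          # cell indices
    anchor_wh = torch.tensor(anchors_px, dtype=p.dtype) / stride        # anchors in grid units

    xy = torch.sigmoid(p[..., 0:2]) + torch.stack((grid_x, grid_y), dim=-1)
    wh = torch.exp(p[..., 2:4]) * anchor_wh.view(1, -1, 1, 1, 2)
    return torch.cat((xy, wh), dim=-1) * stride                         # back to input pixels
```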
@glenn-jocher mAP=56.1 is from yolov3.weights at 608 × 608 input size (not 57.9; I am not sure whether 57.9 is on test-dev2017 or test-dev2014).
gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG * stride, t[:, 4] * nG * stride ## stride=8,16,32
I think this is a small difference, and the same applies to the prediction procedure.
@glenn-jocher you are right, you scale the anchor by the stride. That's what I missed before.
ahhh yes, this is done in models.py, self.anchor_w is already scaled with stride. Ok good, I thought we had a mistake in there somewhere!
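To spell out why that works, a tiny sketch (116 is a standard yolov3.cfg anchor width, the box width is illustrative): dividing the anchor by the stride makes the grid-space target identical to darknet's input-space target.

```python
import math

anchor_w_px, stride, nG, input_size = 116, 32, 13, 416   # 116 is a yolov3.cfg anchor width
box_w_px = 150                                            # illustrative ground-truth width

# darknet: both box and anchor measured in input pixels
t_w_darknet = math.log(box_w_px / anchor_w_px)

# this repo: box in grid units, anchor divided by the stride -> same target value
gw = box_w_px / input_size * nG
t_w_repo = math.log(gw / (anchor_w_px / stride))

print(t_w_darknet, t_w_repo)   # both ~ 0.257
```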
You aren't supposed to test on 2017test/2014test. You are supposed to test on the file named 5k.txt and only on that. I explained it over at #9.
@nirbenz I don't think the mAP on that page was computed on the 5k split. It should be on test-dev (2014 or 2017, I'm not sure; you can read the YOLOv3 paper for more details).
@muye5 the test set you need to run this on is definitely the 5k.txt file, as mentioned on the author's homepage. When running on it I get the 55.9 score with both Darknet and this repo.
@nirbenz The first sentence on the homepage is:
You only look once (YOLO) is a state-of-the-art, real-time object detection system. On a Pascal Titan X it processes images at 30 FPS and has a mAP of 57.9% on COCO test-dev.
As I understand it, nobody would report a self-made split as test-dev, because the ground truth of test-dev is not public. The author made it clear he ran some experiments on the 5k split, but the mAP he compares against other models is on test-dev. Who would compare performance with other models on a different dataset?
@muye5 From the author's homepage (the very page you linked to):
classes= 80
train  = <path-to-coco>/trainvalno5k.txt
valid  = <path-to-coco>/5k.txt
names = data/coco.names
backup = backup
I have also expanded on the reasoning behind this in the issue I linked to:
there is no real difference between COCO14 and COCO17 other than (and this is a big thing) the separation of validation and train. With COCO14 there is a roughly 50-50 split, so common practice is to merge them and choose a small subset for testing. This is what 5k.part and trainvalno5k.part are.
For COCO17 the dataset is already split that way (train+val and test). But because of that, evaluating YOLOv3 (using the original weights) must be done on the 5k split made by the author. Otherwise you are probably testing on part of the train set (a quick way to check this overlap is sketched below).
Good luck.
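A quick, hedged way to check the train/val leakage described above (paths and list-file names are assumptions about a local darknet-style COCO layout): count how many image ids from the author's 5k.txt also appear in a train2017 image list.

```python
import re

def image_ids(list_file):
    """Collect the 12-digit COCO image ids from a darknet-style image list file."""
    ids = set()
    with open(list_file) as f:
        for line in f:
            m = re.search(r'(\d{12})\.jpg', line)   # 2014 and 2017 file names both end in a 12-digit id
            if m:
                ids.add(int(m.group(1)))
    return ids

ids_5k = image_ids('coco/5k.txt')                 # the author's 5k validation split (COCO 2014 images)
ids_train2017 = image_ids('coco/train2017.txt')   # assumed list of COCO 2017 train images
print(f'{len(ids_5k & ids_train2017)} of {len(ids_5k)} 5k images also appear in train2017')
```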
@glenn-jocher Hi, great job with good accuracy!
Which of the scripts do you use to convert the darknet weights to PyTorch .pt weights?
@AlexeyAB thanks! I've been working on our PyTorch YOLOv3 repo since summer '18; I think we finally have most of the kinks worked out, though we are still not quite there on training. The slight bump we report in mAP is mainly due to an updated type of NMS we created (you should be able to see similar bumps in darknet mAP using the same approach).
The process to convert the darknet weights to PyTorch weights is to:
@glenn-jocher Thanks!
So ultralytics/yolov3 loads Darknet cfg/weights files directly: https://github.com/ultralytics/yolov3/blob/cb352be02c7d8653bed408f5cc2cdea58145e678/detect.py#L28-L35
What command should I use to just convert a Darknet model to a PyTorch model, without training? I want to add a link to your repo here: https://github.com/AlexeyAB/darknet#yolo-v3-in-other-frameworks with a very short description of how the weights are converted.
@AlexeyAB I've added a simple conversion function to export from both formats to the other. The process is very simple:
git clone https://github.com/ultralytics/yolov3 && cd yolov3
# darknet to pytorch
python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')"
Success: converted 'weights/yolov3-spp.weights' to 'converted.pt'
# pytorch to darknet
python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')"
Success: converted 'weights/yolov3-spp.pt' to 'converted.weights'
I've added a section to our README: https://github.com/ultralytics/yolov3/blob/master/README.md#darknet-conversion
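As background on what any such converter has to parse first, here is a hedged sketch of the darknet .weights layout as I understand it (the size of the 'seen' counter depends on the darknet version; the path is a placeholder): a few header integers followed by the raw float32 parameters.

```python
import numpy as np

def read_darknet_weights(path='weights/yolov3-spp.weights'):
    """Read the darknet .weights header and the flat float32 parameter array."""
    with open(path, 'rb') as f:
        major, minor, revision = np.fromfile(f, dtype=np.int32, count=3)
        if major * 10 + minor >= 2:
            seen = int(np.fromfile(f, dtype=np.int64, count=1)[0])   # images seen during training
        else:
            seen = int(np.fromfile(f, dtype=np.int32, count=1)[0])
        params = np.fromfile(f, dtype=np.float32)                    # all layer weights, flattened
    return (int(major), int(minor), int(revision)), seen, params

header, seen, params = read_darknet_weights()
print(header, seen, params.shape)
```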
I noticed that the log transformation of the w/h offset in this repo is different from darknet, as I understand it. utils/build_targets @Line 227:
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
Here, xy and wh are all relative to the grid (e.g. input=416, grid=13/26/52), and so is anchor * exp(t_w).
But in darknet, as I understand it, xy are relative to the grid while wh are relative to the input, so anchor * exp(t_w) doesn't need to be scaled. yolo_layer.c @Line 88:
b.w = exp(x[index + 2*stride]) * biases[2*n] / w;
So if predicting with the weight file provided by the author, I think this will cause some mistakes.
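For completeness, a small numeric check of the two width decodes compared in this issue (t_w and the anchor are illustrative; 116 is a yolov3.cfg anchor width): darknet's yolo_layer.c normalizes by the network input width, while a grid-space decode with stride-scaled anchors normalizes by the grid size, and both give the same fraction.

```python
import math

t_w, anchor_w_px, net_w, nG, stride = 0.25, 116, 416, 13, 32

b_w_darknet = math.exp(t_w) * anchor_w_px / net_w           # yolo_layer.c: fraction of input width
b_w_grid    = math.exp(t_w) * (anchor_w_px / stride) / nG   # grid-space decode, anchor pre-divided by stride

print(b_w_darknet, b_w_grid)   # identical (~0.358)
```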