experiencor / keras-yolo2

Easy training on custom datasets. Various backends (MobileNet and SqueezeNet) are supported. A YOLO demo that detects raccoons and runs entirely in the browser is available at https://git.io/vF7vI (not on Windows).
MIT License
1.73k stars, 785 forks

Maybe a terrible mistake #353

Open. rockeyben opened this issue 6 years ago.

rockeyben commented 6 years ago
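# Negative cells: background anchors whose best IoU with any ground truth is below 0.6
conf_mask = conf_mask + tf.to_float(best_ious < 0.6) * (1 - y_true[..., 4]) * self.no_object_scale
# Positive cells: anchors responsible for a ground-truth box
conf_mask = conf_mask + y_true[..., 4] * self.object_scale
# A single count over both kinds of cell, so positives and negatives share one normalizer
nb_conf_box = tf.reduce_sum(tf.to_float(conf_mask > 0.0))
loss_conf = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask) / (nb_conf_box + 1e-6) / 2.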

If you compute the loss this way, rather than computing the negative and positive confidence losses separately, the negative confidence loss will overwhelm the positive one, and the model will converge to a bad state. A simple example:

3*x/x + 5*y/y != (3*x+5*y)/(x+y)

The larger one of x or y becomes relative to the other, the more significant the imbalance.
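To make the arithmetic concrete, here is a small Python sketch. Read x as the number of positive (object) cells and y as the number of negative (background) cells, with 3 and 5 standing in for per-cell losses; the specific counts below are illustrative assumptions, not values from the thread. With separate normalization the result is always 3 + 5, while with a shared denominator each positive cell is weighted by 1/(x + y) instead of 1/x.

```python
def separate(pos_loss, n_pos, neg_loss, n_neg):
    # Each term is averaged over its own count: 3*x/x + 5*y/y
    return pos_loss * n_pos / n_pos + neg_loss * n_neg / n_neg

def combined(pos_loss, n_pos, neg_loss, n_neg):
    # One shared denominator: (3*x + 5*y) / (x + y)
    return (pos_loss * n_pos + neg_loss * n_neg) / (n_pos + n_neg)

def pos_share(n_pos, n_neg):
    # Fraction of the combined loss weight that goes to positive cells
    return n_pos / (n_pos + n_neg)

for n_pos, n_neg in [(10, 10), (10, 1000)]:
    print(n_pos, n_neg,
          separate(3, n_pos, 5, n_neg),
          round(combined(3, n_pos, 5, n_neg), 3),
          round(pos_share(n_pos, n_neg), 4))
# 10 10 8.0 4.0 0.5
# 10 1000 8.0 4.98 0.0099
```

With 100 times more negatives than positives, the positive cells receive under 1% of the combined loss weight.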

rockeyben commented 6 years ago

I can't get your model to fit even 4 images, but if I separate the negative and positive confidence losses, it works. Just FYI.

rodrigo2019 commented 6 years ago

Did you compare training with your modifications against the original? Could you share the results?

bdhammel commented 5 years ago

@rockeyben Could you post your modified loss calculation?

rockeyben commented 5 years ago

@rodrigo2019 Sorry, I've been busy with other work recently, so I couldn't summarize my results until now. Also, my experiments are in a different setting: I am trying to adapt this repo to bird's-eye-view car detection on KITTI (ECCV18: YOLO3D: End-to-end real-time 3D Oriented Object Bounding Box Detection from LiDAR Point Cloud). In this task we have to detect 3D boxes from a bird's-eye view, so the cars are relatively small and the positive part of the output score map is very small. Note that the authors changed Darknet's downsampling from 1/32 to 1/16 to cope with this, but that makes the imbalance between the positive and negative parts of the score map even worse (the negative part overwhelms the positive part).

Using this repo's loss function, I found that no matter how low the loss went, the network still couldn't perform well on just 4 images (it performed well on 1 image but failed on the others). When I trained on the whole dataset, the box confidence scores were relatively low: a perfectly predicted box only got a score of 0.6 or 0.7, so I suspected the negative/positive imbalance. First I tried changing the negative and positive loss weights to overcome the imbalance, but it didn't work even with a ratio of 20:1 (pos:neg). So I changed the loss function to the following version: @bdhammel

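# Negative cells: background anchors whose best IoU with any ground truth is below 0.6
# (tf.to_float is TF1 API; in TF2 the equivalent is tf.cast(x, tf.float32))
conf_mask_neg = tf.to_float(best_ious < 0.6) * (1 - y_true[..., 4]) * self.no_object_scale
# Positive cells: anchors responsible for a ground-truth box
conf_mask_pos = y_true[..., 4] * self.object_scale
# Count each kind of cell separately
nb_conf_box_neg = tf.reduce_sum(tf.to_float(conf_mask_neg > 0.0))
nb_conf_box_pos = tf.reduce_sum(tf.to_float(conf_mask_pos > 0.0))
# Average each term over its own count, then sum the two
loss_conf_neg = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask_neg) / (nb_conf_box_neg + 1e-6) / 2.
loss_conf_pos = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask_pos) / (nb_conf_box_pos + 1e-6) / 2.
loss_conf = loss_conf_neg + loss_conf_pos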
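The difference between the two formulations can be checked outside TensorFlow. The following is a minimal NumPy re-implementation sketch (not the repo's code; the function names and the toy grid are hypothetical). It assumes a 38x38 grid with 5 anchors per cell and 3 objects, simplifies the target confidence to 1 at object cells, uses the object_scale 5.0 and no_object_scale 3.0 values that appear later in this thread, and evaluates a model that has collapsed to predicting background everywhere.

```python
import numpy as np

def conf_loss_combined(true_conf, pred_conf, best_ious, obj_mask,
                       object_scale, no_object_scale):
    # Original version: negatives and positives go into one mask
    conf_mask = (best_ious < 0.6) * (1.0 - obj_mask) * no_object_scale
    conf_mask = conf_mask + obj_mask * object_scale
    # One count over both kinds of cell, so both share one normalizer
    nb_conf_box = np.sum(conf_mask > 0.0)
    sq_err = np.square(true_conf - pred_conf)
    return np.sum(sq_err * conf_mask) / (nb_conf_box + 1e-6) / 2.0

def conf_loss_separated(true_conf, pred_conf, best_ious, obj_mask,
                        object_scale, no_object_scale):
    # Modified version: each term is averaged over its own cell count
    conf_mask_neg = (best_ious < 0.6) * (1.0 - obj_mask) * no_object_scale
    conf_mask_pos = obj_mask * object_scale
    nb_neg = np.sum(conf_mask_neg > 0.0)
    nb_pos = np.sum(conf_mask_pos > 0.0)
    sq_err = np.square(true_conf - pred_conf)
    loss_neg = np.sum(sq_err * conf_mask_neg) / (nb_neg + 1e-6) / 2.0
    loss_pos = np.sum(sq_err * conf_mask_pos) / (nb_pos + 1e-6) / 2.0
    return loss_neg + loss_pos

# Hypothetical toy grid: 38x38 cells, 5 anchors each, 3 objects
obj_mask = np.zeros((38, 38, 5))
obj_mask[5, 5, 0] = obj_mask[20, 11, 2] = obj_mask[30, 30, 4] = 1.0
true_conf = obj_mask.copy()          # simplification: target confidence 1 at objects
best_ious = np.zeros_like(obj_mask)  # no prediction overlaps any ground truth
pred_conf = np.zeros_like(obj_mask)  # a model that predicts "background" everywhere

combined = conf_loss_combined(true_conf, pred_conf, best_ious, obj_mask, 5.0, 3.0)
separated = conf_loss_separated(true_conf, pred_conf, best_ious, obj_mask, 5.0, 3.0)
print(combined, separated)
```

Under the combined normalizer, missing every object costs about 15 / 7220 / 2 ≈ 0.001 of confidence loss; under the separated one it costs about 2.5. This is consistent with the "converges to background" behavior reported here, though a toy example does not prove it.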

With this change it finally performs normally (even with a pos:neg ratio of 5:3).

I haven't run experiments on PASCAL VOC or COCO. In that setting the output score map is only 13x13 (mine is 38x38) and target objects are often large in the image, so the problem may be less serious, but I think it still exists. If you are running experiments on VOC or COCO, please post your results here. Thanks @bdhammel @rodrigo2019
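The grid sizes above translate directly into how many negative slots each positive competes with. This sketch assumes 5 anchors per cell, a 416-pixel input for the 13x13 map at 1/32 downsampling, and a 608-pixel input for the 38x38 map at 1/16 (608 / 16 = 38); the 3-object image and the helper names are hypothetical. It also ignores the best_ious < 0.6 exemption, so the negative count is an upper bound.

```python
def grid_slots(input_size, stride, n_anchors):
    # Side length of the output score map and the total (cell, anchor) slots
    side = input_size // stride
    return side, side * side * n_anchors

def neg_per_pos(input_size, stride, n_anchors, n_objects):
    # Upper bound on negatives per positive: each object occupies one slot
    # and every other slot is counted as a negative
    _, total = grid_slots(input_size, stride, n_anchors)
    return (total - n_objects) / n_objects

print(grid_slots(416, 32, 5))                  # (13, 845)
print(grid_slots(608, 16, 5))                  # (38, 7220)
print(round(neg_per_pos(416, 32, 5, 3), 1))    # 280.7
print(round(neg_per_pos(608, 16, 5, 3), 1))    # 2405.7
```

Going from a 13x13 to a 38x38 map multiplies the number of slots by roughly 8.5, so with a shared normalizer each positive slot's weight shrinks by about the same factor.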

Best

rodrigo2019 commented 5 years ago

Thank you for sharing, I will try to test it this week.

rodrigo2019 commented 5 years ago

Could you share the full code? I don't know exactly where I need to insert this code.

bdhammel commented 5 years ago

@rodrigo2019

Lines 184 and 187 will become:

conf_mask_neg = tf.to_float(best_ious < 0.6) * (1 - y_true[..., 4]) * self.no_object_scale
conf_mask_pos = y_true[..., 4] * self.object_scale

Line 212 will become

nb_conf_box_neg = tf.reduce_sum(tf.to_float(conf_mask_neg > 0.0))
nb_conf_box_pos = tf.reduce_sum(tf.to_float(conf_mask_pos > 0.0))

and Line 217 will become

loss_conf_neg = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask_neg) / (nb_conf_box_neg + 1e-6) / 2.
loss_conf_pos = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask_pos) / (nb_conf_box_pos + 1e-6) / 2.
loss_conf = loss_conf_neg + loss_conf_pos

Unfortunately, after these changes I am still unable to get the network to pass the sanity check of overfitting a small dataset. By that I mean the loss and mAP indicate that the network has overfit, but plotting the detections shows the wrong number of boxes in the wrong places (maybe suggesting a different problem, potentially on my side). For comparison, without these changes the loss indicated that the network could overfit a small dataset, but the mAP did not. This is of course specific to the small dataset I'm working with, so it shouldn't be taken as a blanket statement about the network's behavior in all situations, but it is worth noting.

rodrigo2019 commented 5 years ago

@bdhammel, are you saying that with your modifications you got a better mAP, but when you did the sanity check by plotting the detections you saw worse results?

bdhammel commented 5 years ago

Not worse results, just bad results (results not consistent with the obtained mAP). Maybe a bit better relative to the old loss, but that's a very qualitative judgment based on eyeballing a few images.

Oddly, when I trained on the full dataset with the new loss, I got very respectable results (i.e., better results than when I tried to overfit a small dataset). I'll also note that this was not the case with the old loss, where the network converged to predicting only background (the same behavior rockeyben reported).

rockeyben commented 5 years ago

That's weird. Could you check your visualization code? If the mAP is good (say 80, 90, or higher), I think the visualization should look normal too.

rodrigo2019 commented 5 years ago

Did you check the mAP evaluation function? The IoU threshold is set to 0.3 instead of 0.5.
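A 0.3 IoU threshold is lenient: a box can be visibly misplaced and still count as a true positive. The sketch below uses a standard IoU on (xmin, ymin, xmax, ymax) boxes with made-up coordinates; it is not the repo's evaluation code. A 10x10 prediction shifted 4 pixels off its 10x10 ground truth scores 60 / 140 ≈ 0.43, which is a hit at 0.3 and a miss at 0.5.

```python
def iou(a, b):
    # Boxes are (xmin, ymin, xmax, ymax)
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

gt = (0, 0, 10, 10)
pred = (4, 0, 14, 10)  # same size, shifted 4 px to the right

score = iou(gt, pred)
print(round(score, 3))            # 0.429
print(score >= 0.3, score >= 0.5) # True False
```

So a model can post a high mAP at IoU 0.3 while its plotted boxes still look poorly placed.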

rodrigo2019 commented 5 years ago

I ran a training with these parameters and your modified loss function:

"object_scale":         5.0 ,
"no_object_scale":      3.0,
"coord_scale":          1.0,
"class_scale":          1.0,

The mAP improved from 0.875 to 0.901 on my own dataset (around 5000 objects of a single class), but the false positive rate was much higher. Now I understand what @bdhammel said: the results don't look consistent with the obtained mAP.

rockeyben commented 5 years ago

Could it be the score threshold? In my setting, the new loss pushes the scores of all boxes up (say, from an average of 0.5 to an average of 0.8), so you may need to raise the threshold to suppress false positives. Alternatively, plot the scores of the false positives and check whether they are 'real' false positives that score higher than true positives. If so, you could adjust the parameters and lower 'object_scale' a little. And if 90.1 is the result on your training set, could you share how many epochs you trained? Maybe the model is undertrained, because I think an IoU of 0.3 is too low for mAP evaluation. You could change the IoU threshold to 0.5 and keep training until the mAP overfits at that setting.

bdhammel commented 5 years ago

I'll start with the disclaimer that I really haven't had the chance to dive into this issue. This is a project I'm juggling with my left hand while I have a lot of other stuff going on, so I can't rule out that the problem is due to "pilot error." Unfortunately, I won't be able to revisit it for a while to give it the attention it deserves. But, in case anyone wants to reproduce what I'm seeing:

rockeyben commented 5 years ago

@bdhammel I know what's happening. Look at line 471 in frontend.py: if we call

yolo.predict()

we end up in the function at line 77 of utils.py:

def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, nms_threshold=0.3):

If we don't pass these parameters, the object threshold defaults to 0.3. In your visualization, the true positive gets a score of 0.709 while the false positives score around 0.3, so what you need to do is raise the object threshold to suppress the false positives.
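The effect of raising the threshold can be sketched independently of the repo. This is a hypothetical stand-in for the score filtering that obj_threshold controls in decode_netout, not that function itself; the detection dicts are made up, with the true positive at the 0.709 mentioned here and the false-positive scores chosen to sit around 0.3.

```python
def filter_by_score(detections, obj_threshold):
    # Keep only detections whose score exceeds the threshold
    return [d for d in detections if d["score"] > obj_threshold]

# Hypothetical detections for one image
detections = [
    {"label": "tp", "score": 0.709},
    {"label": "fp", "score": 0.31},
    {"label": "fp", "score": 0.33},
    {"label": "fp", "score": 0.29},
]

print(len(filter_by_score(detections, 0.3)))  # 3 (default threshold keeps two FPs)
print(len(filter_by_score(detections, 0.5)))  # 1 (only the true positive survives)
```

Since the new loss shifts all scores upward, a threshold tuned for the old loss is likely to be too permissive.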

rodrigo2019 commented 5 years ago

Should I use the model with the best validation loss or the best mAP? The best-mAP model doesn't seem to care about false positives.

PS: I edited the code to compute the validation mAP after every epoch.

rockeyben commented 5 years ago

It depends on your requirements. Intuitively, I'd pick the best-mAP model, but if it produces false positives, you can adjust your post-processing code to remove them.
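One plausible reason the best-mAP model "doesn't care" about false positives: AP ranks detections by score, so false positives scored below every true positive add no recall and never lower the interpolated precision at earlier recall levels. The sketch below assumes all-point interpolated AP (as used by PASCAL VOC from 2010 on); the repo's own evaluation may differ in detail, and the scores are made up.

```python
def average_precision(scores, is_tp, n_gt):
    # Rank detections by descending score
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    tp = fp = 0
    recalls, precisions = [], []
    for i in order:
        if is_tp[i]:
            tp += 1
        else:
            fp += 1
        recalls.append(tp / n_gt)
        precisions.append(tp / (tp + fp))
    # Precision envelope: make precision non-increasing from right to left
    for k in range(len(precisions) - 2, -1, -1):
        precisions[k] = max(precisions[k], precisions[k + 1])
    # Sum precision over each recall increment
    ap, prev_r = 0.0, 0.0
    for r, p in zip(recalls, precisions):
        ap += (r - prev_r) * p
        prev_r = r
    return ap

clean = average_precision([0.9, 0.8, 0.7], [True, True, True], 3)
low_fps = average_precision([0.9, 0.8, 0.7, 0.3, 0.29],
                            [True, True, True, False, False], 3)
high_fp = average_precision([0.95, 0.9, 0.8, 0.7],
                            [False, True, True, True], 3)
print(round(clean, 3), round(low_fps, 3), round(high_fp, 3))  # 1.0 1.0 0.75
```

Low-score false positives leave AP untouched, while a single false positive ranked above the true positives costs 0.25 here. That is why a score threshold in post-processing can remove such false positives without giving up mAP.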