andy-yun / pytorch-0.4-yolov3

Yet Another Implementation of PyTorch 0.4.1 and YOLOv3 on Python 3
MIT License
279 stars, 72 forks

Possibly Incorrect Loss Terms #22

Closed: glenn-jocher closed this issue 5 years ago

glenn-jocher commented 5 years ago

Hello, thank you for your YOLOv3 repo. I noticed your loss terms differ from the official YOLOv3 loss in several ways:

  1. I think you should use BCELoss for loss_cls, as the YOLOv3 paper section 2.2 clearly states "During training we use binary cross-entropy loss for the class predictions."

  2. Why is MSELoss used in place of BCELoss for loss_conf? Did you make this choice yourself or did you see this in darknet?

  3. Why divide loss_coord by 2?

https://github.com/andy-yun/pytorch-0.4-yolov3/blob/master/yolo_layer.py#L161-L164

```python
loss_coord = nn.MSELoss(size_average=False)(coord*coord_mask, tcoord*coord_mask)/2
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)
loss_cls = nn.CrossEntropyLoss(size_average=False)(cls, tcls) if cls.size(0) > 0 else 0
loss = loss_coord + loss_conf + loss_cls
```
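
For illustration, the change raised in points 1 and 2 could be tried along these lines. This is only a minimal sketch, not the repository's code: the tensor shapes and random values are made up, `conf` and `cls` are assumed to already be sigmoid outputs in (0, 1), and `tcls_onehot` is a hypothetical one-hot encoding of the class targets (BCELoss expects per-class 0/1 targets rather than class indices).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the tensors in yolo_layer.py (shapes are assumptions).
n_cells, n_obj, n_classes = 12, 3, 80
conf = torch.sigmoid(torch.randn(n_cells))           # predicted objectness in (0, 1)
tconf = torch.zeros(n_cells)
tconf[:n_obj] = 1.0                                  # target objectness (0 or 1)
conf_mask = torch.ones(n_cells)                      # cells that contribute to the loss

cls = torch.sigmoid(torch.randn(n_obj, n_classes))   # per-class probabilities
tcls = torch.tensor([0, 17, 56])                     # class indices, as CrossEntropyLoss uses
tcls_onehot = torch.zeros(n_obj, n_classes)
tcls_onehot[torch.arange(n_obj), tcls] = 1.0         # one-hot targets for BCELoss

bce = nn.BCELoss(reduction='sum')
mse = nn.MSELoss(reduction='sum')

# Select the contributing cells rather than multiplying by the mask.
keep = conf_mask > 0
loss_conf_mse = mse(conf[keep], tconf[keep])         # the current choice in the snippet above
loss_conf_bce = bce(conf[keep], tconf[keep])         # the alternative asked about in point 2
loss_cls_bce = bce(cls, tcls_onehot)                 # the alternative suggested in point 1

print(loss_conf_mse.item(), loss_conf_bce.item(), loss_cls_bce.item())
```

If the network outputs raw logits instead of sigmoid probabilities, `nn.BCEWithLogitsLoss` would be the corresponding choice.
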
jaelim commented 5 years ago

@glenn-jocher refer to a comment in https://github.com/andy-yun/pytorch-0.4-yolov3/issues/20#issuecomment-421339309

andy-yun commented 5 years ago

1, 2. There is plenty of room to improve the system. If you have doubts about the code, please modify and test it. I used MSELoss and CrossEntropyLoss in yolov2 (region_layer.py), and I did not change CrossEntropyLoss in yolov3 (yolo_layer.py). So, would you please report back which criterion performs better once you have tested both?

3. Dividing the loss by 2 only sets its weight in the combination of the losses. In my opinion, loss weights depend on the task you apply them to, so there is no specific meaning.

doobidoob commented 5 years ago

@andy-yun @glenn-jocher I tested the BCE and MSE criteria on "conf". I only changed the conf loss function; everything else is the same (x, y, w, h use the MSE loss function, and class uses cross-entropy).

My experiment environment is mentioned below.

In my experiment, BCE performs better than MSE. (I used the validation code from https://github.com/eriklindernoren/PyTorch-YOLOv3. Please note that this is not a correct mAP calculation, but I used it for a simple comparison. With that code, the original YOLOv3 weights give 0.58 mAP.)

Please see the table below. [image]

BCE gives a better mAP than MSE, and in the shell screenshot (see below) BCE has nonzero nPP values, while MSE still has zero nPP on the last yolo layer. However, the MSE loss is lower than the BCE loss. I don't know the relation between nPP and loss... [image]
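
The raw loss values of the two criteria are not directly comparable, which may be part of why a lower MSE loss does not mean better detections. The sketch below uses plain Python and assumes `conf` is a sigmoid output with target 1; the logit values are arbitrary. It shows that for the same prediction MSE is always the smaller number, and that its gradient with respect to the logit shrinks when the prediction is confidently wrong. It does not by itself explain the nPP behaviour.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def mse(p, t):
    return (p - t) ** 2

def bce(p, t):
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

def mse_grad_logit(p, t):
    # d/dz (sigmoid(z) - t)^2 = 2 * (p - t) * p * (1 - p)
    return 2 * (p - t) * p * (1 - p)

def bce_grad_logit(p, t):
    # d/dz BCE(sigmoid(z), t) = p - t
    return p - t

t = 1.0  # an object is present
rows = []
for z in (-6.0, -2.0, 0.0, 2.0):  # arbitrary logits, from confidently wrong to mostly right
    p = sigmoid(z)
    rows.append((z, mse(p, t), bce(p, t), mse_grad_logit(p, t), bce_grad_logit(p, t)))

for z, l_mse, l_bce, g_mse, g_bce in rows:
    print(f"z={z:+.1f}  mse={l_mse:.4f}  bce={l_bce:.4f}  "
          f"dmse/dz={g_mse:+.4f}  dbce/dz={g_bce:+.4f}")
```
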

Anyway, it does not converge to the author's performance. I think that performance cannot be reached because of the batch size, so I ran a simple experiment. (I only changed the conf loss function to BCE.) My experiment environment is mentioned below.

I only ran 5 epochs and got 0.152 mAP. (An out-of-memory error occurs at 5 epochs or more, because the resize size increases.) Anyway, this value is better than the batch-size-8 result at the same epoch.

I want to continue training with a batch size of 64, but I cannot because of the memory error. I do not know why this problem still happens. Please share if you have experimental results. Thanks.

door5719 commented 5 years ago

@doobidoob Maybe the out-of-memory error is caused by data augmentation when resizing to 608*608.
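
As a rough back-of-the-envelope check of that guess, the snippet below compares pixel counts per image. It assumes a 416*416 baseline input (a common YOLOv3 size, not stated in this thread) and that activation memory grows roughly in proportion to the number of input pixels, which is only an approximation.

```python
base = 416      # assumed baseline input side length (not from this thread)
largest = 608   # largest side length mentioned above

# Ratio of pixels per image at the largest size versus the assumed baseline.
ratio = (largest * largest) / (base * base)
print(f"{ratio:.2f}x pixels per image at {largest} vs {base}")
```

Under those assumptions, a batch that just fits at the smaller size would need roughly twice the activation memory at 608*608.
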

andy-yun commented 5 years ago

Try the latest code, which changes the loss term to BCE loss.

andy-yun commented 5 years ago

With 80 epochs, 45.9 mAP is obtained with the latest code.