lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0
2.21k stars 242 forks

Validation loss larger/has different scale than training loss #141

Open hashJoe opened 9 months ago

hashJoe commented 9 months ago

Hey, I inserted criterion losses into metric_logger to have a better overview of the evaluation besides coco_eval_bbox. I simply uncommented the lines referenced below and returned the losses.

https://github.com/lyuwenyu/RT-DETR/blob/3330eca679a7d7cce16bbb10509099174a2f40bf/rtdetr_pytorch/src/solver/det_engine.py#L123-L133
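For reference, the logic behind that commented-out block boils down to the following sketch (paraphrased, with plain floats standing in for tensors; the names `loss_dict` and `weight_dict` are assumed from the linked file). It explains the `_unscaled` keys visible in the Test/Val log below: raw criterion losses are logged once as-is and once multiplied by their loss weights.

```python
# Sketch of the eval-loss logging being discussed (assumed names, not the
# exact repo code): scale each raw criterion loss by its weight, keep the
# raw values under a "_unscaled" suffix, and sum the scaled ones.
def collect_eval_losses(loss_dict, weight_dict):
    scaled = {k: v * weight_dict[k] for k, v in loss_dict.items() if k in weight_dict}
    unscaled = {f'{k}_unscaled': v for k, v in loss_dict.items()}
    return {'loss': sum(scaled.values()), **scaled, **unscaled}

# Toy example: with weights 5.0 / 2.0 (as used for bbox / giou in the configs),
# 'loss' = 2.0*5.0 + 1.0*2.0 = 12.0, and the raw values survive as *_unscaled.
stats = collect_eval_losses(
    {'loss_bbox': 2.0, 'loss_giou': 1.0},
    {'loss_bbox': 5.0, 'loss_giou': 2.0},
)
```

This matches the log below, where e.g. `loss_bbox` is exactly 5x `loss_bbox_unscaled`.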

However, the values of these losses differ from the training losses. Below are example train and test losses from the same epoch:

e.g., Train

Averaged stats: lr: 0.000010  loss: 6.0684 (6.0180)  loss_bbox: 0.1730 (0.1718)  loss_bbox_aux_0: 0.1756 (0.1796)  loss_bbox_aux_1: 0.1725 (0.1743)  loss_bbox_aux_2: 0.2110 (0.2254)  loss_bbox_dn_0: 0.2406 (0.2460)  loss_bbox_dn_1: 0.1696 (0.1705)  loss_bbox_dn_2: 0.1638 (0.1622)  loss_giou: 0.2230 (0.2254)  loss_giou_aux_0: 0.2216 (0.2319)  loss_giou_aux_1: 0.2151 (0.2276)  loss_giou_aux_2: 0.2528 (0.2703)  loss_giou_dn_0: 0.2726 (0.2790)  loss_giou_dn_1: 0.2003 (0.2098)  loss_giou_dn_2: 0.1918 (0.2026)  loss_vfl: 0.4495 (0.4768)  loss_vfl_aux_0: 0.5828 (0.5614)  loss_vfl_aux_1: 0.4664 (0.4884)  loss_vfl_aux_2: 0.6356 (0.6641)  loss_vfl_dn_0: 0.3318 (0.3241)  loss_vfl_dn_1: 0.2707 (0.2673)  loss_vfl_dn_2: 0.2640 (0.2594)

Test/Val

Averaged stats: loss: 31790.3984 (32923.0916)  loss_bbox: 31785.6328 (32918.2895)  loss_giou: 4.3206 (4.4140)  loss_vfl: 0.4160 (0.3878)  loss_bbox_unscaled: 6357.1265 (6583.6578)  loss_giou_unscaled: 2.1603 (2.2070)  loss_vfl_unscaled: 0.4160 (0.3878)

loss_vfl seems to be alright, but loss_bbox and loss_giou are much larger, although they do decrease as the model is trained for more epochs.

Is this an issue? What's the reason behind this? Thank you!

amndzr commented 7 months ago

@hashJoe Hey, for me loss_bbox and loss_giou are not decreasing, but rather constant. Do you have an idea what the reason could be? My train_loss is going down and the metrics up.

How big is your dataset?

hashJoe commented 7 months ago

> @hashJoe Hey, for me loss_bbox and loss_giou are not decreasing, but rather constant. Do you have an idea what the reason could be? My train_loss is going down and the metrics up.
>
> How big is your dataset?

@amndzr Around 5K images and maybe 13K boxes. Losses are going down and results are good. Are your results still looking good?

Try debugging into this and checking what goes wrong: https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_criterion.py#L152-L171

amndzr commented 7 months ago

@hashJoe I can't seem to find anything suspicious when debugging. My results are good with respect to train_loss, AP as well as AR. My dataset has a size of around 10K images.

[attached screenshot]

20231211 commented 4 months ago

Hello, I have the same problem as you. Do you know the cause of the problem, and how did you solve it?

hashJoe commented 4 months ago

@amndzr I see, it is just the validation loss. The same reply below might help: @20231211 The transforms applied to the bounding boxes in the validation stage differ from those in the training stage. In the training stage, the box coordinates are converted to cxcywh format and normalized: https://github.com/lyuwenyu/RT-DETR/blob/64878acad2f58ed34579e5a5ec45da1044587e09/rtdetr_pytorch/src/data/transforms.py#L132-L139

set in: https://github.com/lyuwenyu/RT-DETR/blob/64878acad2f58ed34579e5a5ec45da1044587e09/rtdetr_pytorch/configs/rtdetr/include/dataloader.yml#L20

whereas in the validation stage this conversion is skipped, so the target boxes stay unnormalized, which results in wrong scaling inside the criterion. https://github.com/lyuwenyu/RT-DETR/blob/64878acad2f58ed34579e5a5ec45da1044587e09/rtdetr_pytorch/configs/rtdetr/include/dataloader.yml#L35
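A toy calculation (invented numbers, not from the repo) shows why this mismatch inflates loss_bbox by orders of magnitude: the model predicts normalized cxcywh boxes in [0, 1], while the validation targets remain in absolute-pixel xyxy, so the L1 bbox loss compares values on completely different scales (and in different formats).

```python
# Toy demonstration of the scale mismatch between normalized cxcywh
# predictions and unnormalized pixel xyxy targets.
def l1(a, b):
    # mean absolute difference over the 4 box coordinates
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

pred = [0.5, 0.5, 0.2, 0.3]              # model output: normalized cxcywh
target_norm = [0.52, 0.48, 0.22, 0.28]   # same box, normalized cxcywh (training)
target_px = [262.4, 163.2, 403.2, 297.6] # same box, pixel xyxy on a 640x480 image

print(l1(pred, target_norm))  # ~0.02 -> the scale seen in the training log
print(l1(pred, target_px))    # ~281  -> the scale seen in the validation log
```

With pixel-space targets the loss is dominated by the raw coordinate magnitudes, which is why loss_bbox and loss_giou still shrink slowly as predictions improve, but never reach the training scale.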

@lyuwenyu Could you please check this? Thanks! :)

20231211 commented 4 months ago

@hashJoe Thank you for your help, so I should replace {type: ConvertDtype} with {type: ConvertBox, out_fmt: 'cxcywh', normalize: True}, right? Do I need to make any other changes?
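For reference, the change under discussion would make the validation ops in dataloader.yml end with the same box conversion the training ops use, roughly like this (a sketch only; the surrounding op list is assumed from the configs and may not match the exact file):

```yaml
val_dataloader:
  dataset:
    transforms:
      ops:
        - {type: Resize, size: [640, 640]}
        - {type: ToImageTensor}
        - {type: ConvertDtype}
        - {type: ConvertBox, out_fmt: 'cxcywh', normalize: True}
```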

hashJoe commented 4 months ago

@20231211 That's what I did. The criterion called here: https://github.com/lyuwenyu/RT-DETR/blob/64878acad2f58ed34579e5a5ec45da1044587e09/rtdetr_pytorch/src/solver/det_engine.py#L123 will then compare target box coordinates scaled the same way as the model output, as in https://github.com/lyuwenyu/RT-DETR/blob/64878acad2f58ed34579e5a5ec45da1044587e09/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L555 ... if I am not mistaken, the model does not treat boxes differently in validation than in training.

20231211 commented 4 months ago

@hashJoe Thank you so much!

amndzr commented 4 months ago

@20231211 No, I didn't continue working on it. However, I trained the model not on my custom dataset but on the COCO2017 dataset to see if the loss was still constant. What I found was that the validation loss was still constant, but when testing the model on unseen data the results were good, so I didn't bother.

20231211 commented 4 months ago

@amndzr I tried the workaround above, adding {type: ConvertBox, out_fmt: 'cxcywh', normalize: True}; the test loss now drops, but is still about an order of magnitude higher than the training loss.