lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0
2.31k stars 258 forks source link

后处理问题 #317

Open WuShaogui opened 4 months ago

WuShaogui commented 4 months ago

https://github.com/lyuwenyu/RT-DETR/blob/2b88d5d53bcbfbb70329bc9c007fdf7e76cf90dc/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L53

这里使用是模型直接输出结果,未乘上图片尺寸,paddle版本无论是否use_focal_loss,都是使用还原尺寸的,猜测这里应该也是

lyuwenyu commented 4 months ago

前面已经乘过了

https://github.com/lyuwenyu/RT-DETR/blob/2b88d5d53bcbfbb70329bc9c007fdf7e76cf90dc/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L38

WuShaogui commented 4 months ago

前面乘了后面没用,相当于没乘

lyuwenyu commented 4 months ago

哦 变量名不对 确实是个bug..

https://github.com/lyuwenyu/RT-DETR/pull/319

YoungsunPan commented 4 months ago

不使用use_focal_loss(用celoss),你有训练成功么。

WuShaogui commented 4 months ago

不使用use_focal_loss(用celoss),你有训练成功么。

可以收敛,但是同样位置还是有问题,记得把以下

https://github.com/lyuwenyu/RT-DETR/blob/2b88d5d53bcbfbb70329bc9c007fdf7e76cf90dc/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L48

改为

scores = F.softmax(logits, dim=-1)

不然最后一个类别无法预测,因为结果被截断了,调试结果如下:

image

lyuwenyu commented 4 months ago

这里你训练数据是多少类别? @WuShaogui

WuShaogui commented 4 months ago

这里你训练数据是多少类别? @WuShaogui

训练自己的数据集,就2个类别

lyuwenyu commented 4 months ago

你这个https://github.com/lyuwenyu/RT-DETR/pull/320 修改之后有跑精度验证没 精度符合预期嘛

WuShaogui commented 4 months ago

你这个#320 修改之后有跑精度验证没 精度符合预期嘛

精度符合预期,但是数据集只有50张,类别只有2个,单图目标少于5个,整体比较小,可能没参考意义