TexasInstruments / edgeai-yolov5

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite. Forked from https://ultralytics.com/yolov5
https://github.com/TexasInstruments/edgeai
GNU General Public License v3.0

Loss formula in paper "YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss" #124

Open Charmnut1 opened 1 year ago

Charmnut1 commented 1 year ago

Hello, I would like to ask about the keypoint loss in YOLO-Pose, which is crucial for me. In the source code, the loss.py file uses exp(-e), but the formula in the paper has no negative sign in the exponent, so it differs from the OKS formula! According to the paper's formula, the larger d becomes, the larger exp() becomes, and therefore the smaller the loss. Isn't that wrong? Or am I misunderstanding something in the formula? I look forward to your help. Thank you very much! In loss.py it is 1 - torch.exp(-e):

```python
import torch

class KeypointLoss(torch.nn.Module):  # enclosing class assumed; only forward() was quoted
    def __init__(self, sigmas):
        super().__init__()
        self.sigmas = sigmas  # per-keypoint OKS sigmas (tensor)

    def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
        """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
        d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
        kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
        # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
        e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2  # from cocoeval
        return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
```
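The direction of the effect can be checked numerically. Below is a minimal, torch-free sketch: it uses plain `math` instead of `torch` and evaluates a single keypoint with no mask and no `kpt_loss_factor`. The function name `keypoint_loss`, its `sign` switch, and the values of `area` and `k` are hypothetical, chosen only for illustration; `area` plays the role of s² and `k` the per-keypoint constant k_n in the paper's formula.

```python
import math


def keypoint_loss(d_sq, k, area, sign=-1.0):
    """Per-keypoint term 1 - exp(sign * e), with e = d^2 / (2 * area * k^2).

    Hypothetical helper for illustration:
    sign=-1.0 mirrors `1 - torch.exp(-e)` in loss.py (the OKS form);
    sign=+1.0 mirrors the paper's formula as printed (no minus sign).
    `area` stands in for s^2 and `k` for the per-keypoint constant k_n.
    """
    e = d_sq / (2.0 * area * k ** 2)
    return 1.0 - math.exp(sign * e)


area, k = 100.0, 0.5  # hypothetical values, for illustration only
for dist in (0.0, 2.0, 5.0, 10.0):
    d_sq = dist ** 2  # squared pixel distance, like `d` in loss.py
    with_minus = keypoint_loss(d_sq, k, area)                # as in loss.py
    without_minus = keypoint_loss(d_sq, k, area, sign=1.0)   # as printed in the paper
    print(f"d={dist:5.1f}  1-exp(-e)={with_minus:.4f}  1-exp(+e)={without_minus:.4f}")
```

With the minus sign (as in loss.py) the per-keypoint loss is 0 for a perfect prediction and rises toward 1 as the distance grows. With the sign dropped (as the paper's formula is printed) it is negative for any nonzero distance and keeps falling, which would reward larger errors.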

Loss formula in the paper:

$$\text{LOSS} = 1 - \frac{\sum_n \exp\left(\frac{d_n^2}{2 s^2 k_n^2}\right)\delta(v_n > 0)}{\sum_n \delta(v_n > 0)}$$

But OKS is:

$$\text{OKS} = \frac{\sum_n \exp\left(\frac{-d_n^2}{2 s^2 \sigma_n^2}\right)\delta(v_n > 0)}{\sum_n \delta(v_n > 0)}$$
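A short derivation linking the two formulas to the code, under stated assumptions: that the paper intends the standard OKS exponent with the minus sign; that s² is the object area and the per-keypoint constant is k_n = 2·sigmas_n, which is how pycocotools' cocoeval builds `vars = (sigmas * 2)**2`; and that `kpt_mask` is 0/1. The exponent then becomes

$$
e_n = \frac{d_n^2}{2 s^2 k_n^2}
    = \frac{d_n^2}{2 \cdot \mathrm{area} \cdot (2\,\mathrm{sigmas}_n)^2}
    = \frac{d_n^2}{(2\,\mathrm{sigmas}_n)^2} \cdot \frac{1}{\mathrm{area}} \cdot \frac{1}{2}
$$

which is the line `e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2` (its `d` already holds $d_n^2$). With $N$ keypoint entries in total and $N_{vis} = \sum_n \delta(v_n > 0)$, the returned value is

$$
\frac{N}{N_{vis}} \cdot \frac{1}{N} \sum_n \left(1 - \exp(-e_n)\right)\delta(v_n > 0)
  = 1 - \frac{\sum_n \exp(-e_n)\,\delta(v_n > 0)}{\sum_n \delta(v_n > 0)}
$$

i.e. one minus OKS, pooled over all keypoints passed to `forward` and ignoring the `1e-9` guards. Because $e_n \ge 0$, each $\exp(-e_n) \in (0, 1]$, so the loss lies in $[0, 1)$ and grows with $d_n$. Without the minus sign, $\exp(e_n) \ge 1$, and the formula as printed would give a loss $\le 0$ that falls as $d_n$ grows, which is exactly the behaviour described in the question.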

MR-STUZHANG commented 11 months ago

Hello, I noticed this issue too. Have you solved it yet?

Charmnut1 commented 11 months ago

> Hello, I noticed this issue too. Have you solved it yet?

Not yet.

xhf3571 commented 1 day ago

> > Hello, I noticed this issue too. Have you solved it yet?
>
> Not yet.

Hello, may I ask if the function def forward(self, pred_kpts, gt_kpts, kpt_mask, area) is located in utils/loss.py? I couldn't find the corresponding function.