yzqxy / Yolov8_obb_Prune_Track

GNU General Public License v3.0
176 stars, 13 forks

decode result incorrect #28

Closed: Handsome-cp closed this issue 10 months ago

Handsome-cp commented 10 months ago

I trained the YOLOv8 OBB detector on my own data and tested the .pt model; the results are correct (although the angle regression is not very good).

After that, I exported the model to ONNX. My export version of `forward` is:

```python
import math

import torch

# make_anchors, dist2bbox and self.dfl come from the repo's YOLOv8 utilities
# (their import path is not shown here).


def forward_export(self, x):
    shape = x[0].shape  # BCHW
    for i in range(self.nl):
        # per level: box (DFL logits), angle, class logits
        x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](x[i]), self.cv3[i](x[i])), 1)
    if self.dynamic or self.shape != shape:
        self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
        self.shape = shape

    # DFL box decoding
    box, theta, cls = torch.cat([xi.view(shape[0], self.no_box, -1) for xi in x], 2).split(
        (self.reg_max * 4, self.theta, self.nc), 1)
    # xywh in input-image pixels (grid units * stride)
    dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

    theta_pred = (theta.sigmoid() - 0.5) * math.pi  # θ ∈ [-pi/2, pi/2)
    N, _, Fsize = cls.shape
    cls_scores = cls.sigmoid().permute(0, 2, 1).reshape(N, Fsize, -1)
    bbox_preds = dbox.permute(0, 2, 1).reshape(N, Fsize, -1)
    angle_preds = theta_pred.permute(0, 2, 1).reshape(N, Fsize, -1)
    # output: [N, Fsize, 4 + self.theta + self.nc]
    return torch.cat([bbox_preds, angle_preds, cls_scores], dim=-1)
```
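
For reference, the box math in this head can be reproduced for a single anchor in NumPy. This is a minimal sketch, assuming `reg_max = 16` (the YOLOv8 default) and that `self.dfl` and `dist2bbox` behave as in upstream YOLOv8 (softmax over the bins, then the expected bin index). The helper names below are illustrative, not taken from the repo. It shows that w = (l + r) * stride, so w and h are in input-image pixels and can reach 2 * (reg_max - 1) * stride.

```python
import numpy as np

REG_MAX = 16  # assumption: YOLOv8 default number of DFL bins per side


def dfl_expectation(logits):
    """DFL decode: softmax over REG_MAX bins, then the expected bin index.

    logits: shape (4, REG_MAX) for one anchor, rows ordered l, t, r, b.
    Returns distances in grid-cell units, each in [0, REG_MAX - 1].
    """
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    prob = np.exp(logits)
    prob /= prob.sum(axis=1, keepdims=True)
    return prob @ np.arange(REG_MAX, dtype=np.float64)


def dist2xywh(ltrb, anchor, stride):
    """Same math as dist2bbox(..., xywh=True) followed by '* self.strides'."""
    l, t, r, b = ltrb
    ax, ay = anchor  # anchor center in grid-cell units (cell index + 0.5)
    x1, y1, x2, y2 = ax - l, ay - t, ax + r, ay + b
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1  # i.e. w = l + r, h = t + b
    return np.array([cx, cy, w, h]) * stride


def theta_from_logit(z):
    """(sigmoid(z) - 0.5) * pi maps a logit into the range -pi/2 .. pi/2."""
    return (1.0 / (1.0 + np.exp(-z)) - 0.5) * np.pi


# All probability mass on the last bin for every side, stride-32 anchor at cell (10, 10).
peaked = np.full((4, REG_MAX), -1e4)
peaked[:, -1] = 0.0
print(dist2xywh(dfl_expectation(peaked), (10.5, 10.5), 32))  # → [336. 336. 960. 960.]
```

With uniform bin probabilities (all logits equal), the expected distance is 7.5 cells per side, so w = h = 15 * stride, which is 480 px at stride 32 and 120 px at stride 8.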

When decoding the inference output, the w and h values (the 3rd and 4th values of each row of `bbox_preds`) are very large (more than 600 for plane objects). Could you help me find out what the problem is?
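
For comparison, here is one hedged way to decode the exported tensor on the client side. It assumes the column layout returned by `forward_export` with a single angle channel, i.e. `[cx, cy, w, h, theta, cls_0 ... cls_{nc-1}]` per anchor, with coordinates in the (letterboxed) input-image pixels and class scores already passed through sigmoid. `decode_obb_output` and `xywht_to_corners` are hypothetical helper names, not part of the repo, and the corner ordering and angle direction are assumptions that may not match the repo's own post-processing. NMS is not included.

```python
import numpy as np


def decode_obb_output(out, conf_thres=0.25):
    """Split one image's exported output into filtered rotated boxes.

    Assumed layout per row: [cx, cy, w, h, theta, cls scores...].
    Returns (xywh, theta, conf, class_id) for rows above conf_thres.
    """
    out = np.asarray(out)
    if out.ndim == 3:  # (N, Fsize, C): take the first image of the batch
        out = out[0]
    xywh, theta, scores = out[:, :4], out[:, 4], out[:, 5:]
    cls_id = scores.argmax(axis=1)
    conf = scores.max(axis=1)
    keep = conf > conf_thres
    return xywh[keep], theta[keep], conf[keep], cls_id[keep]


def xywht_to_corners(cx, cy, w, h, theta):
    """Four corners of a rotated rectangle; theta in radians (assumed convention)."""
    c, s = np.cos(theta), np.sin(theta)
    dx = np.array([w / 2, -w / 2, -w / 2, w / 2])
    dy = np.array([h / 2, h / 2, -h / 2, -h / 2])
    xs = cx + dx * c - dy * s
    ys = cy + dx * s + dy * c
    return np.stack([xs, ys], axis=1)  # shape (4, 2)
```

If the image was letterboxed before inference, the decoded centers and sizes still have to be mapped back (subtract the padding, divide by the resize ratio) before comparing them with the original image.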