Melody-Zhou / tensorRT_Pro-YOLOv8

This repository is based on shouxieai/tensorRT_Pro, with adjustments to support YOLOv8.
MIT License
240 stars, 40 forks

Error when exporting an OBB model after modifying the source code according to the docs #15

Open zhenhuamo opened 7 months ago

zhenhuamo commented 7 months ago

I get the error below with this modified code:

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        bs = x[0].shape[0]  # batch size
        angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits
        # NOTE: set angle as an attribute so that decode_bboxes could use it.
        angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
        # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
        if not self.training:
            self.angle = angle
        x = self.detect(self, x)
        if self.training:
            return x, angle
        # return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
        # changed to:
        return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

Sizes of tensors must match except in dimension 1. Expected size 19 but got size 8400 for tensor number 1 in the list.
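The numbers in this message can be reproduced with plain shape arithmetic. The sketch below is an inference, not something confirmed in this thread: it assumes Detect.forward was also patched to return its export output already permuted to (bs, N, 4 + nc), so that x reaches the OBB return line as (1, 8400, 19) while angle is (1, 1, 8400), assuming 15 DOTAv1 classes and 8400 anchors for a 640×640 input. cat_shape and permute_021 are hypothetical helpers that only mimic the shape rules of torch.cat and Tensor.permute(0, 2, 1), so the check runs without PyTorch.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def cat_shape(shapes: Sequence[Shape], dim: int) -> Shape:
    """Hypothetical helper mimicking torch.cat's rule: every dim except `dim` must match."""
    first = shapes[0]
    for idx, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise ValueError(f"tensor {idx} has {len(shape)} dims, expected {len(first)}")
        for d, (a, b) in enumerate(zip(first, shape)):
            if d != dim and a != b:
                raise ValueError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {a} but got size {b} for tensor number {idx} in the list."
                )
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)  # sizes along `dim` add up
    return tuple(out)


def permute_021(shape: Shape) -> Shape:
    """Hypothetical helper: shape of a 3-D tensor after .permute(0, 2, 1)."""
    b, c, n = shape
    return (b, n, c)


# Assumed values: batch 1, 15 DOTAv1 classes, 640x640 input (8400 anchors), ne=1 angle channel.
bs, nc, anchors = 1, 15, 8400
angle = (bs, 1, anchors)              # OBB angle branch: (bs, ne, N)
detect_raw = (bs, 4 + nc, anchors)    # unpatched Detect export output: (bs, 4 + nc, N)

# Concatenate on the channel dim first, then permute once.
fixed = permute_021(cat_shape([detect_raw, angle], dim=1))
print(fixed)  # (1, 8400, 20)

# Suspected failure: Detect already permuted its output, then OBB concatenates on dim 1.
try:
    cat_shape([permute_021(detect_raw), angle], dim=1)
except ValueError as err:
    print(err)  # prints the same size-mismatch message as reported above
```

If this assumption holds, only one place should permute during an OBB export: permuting inside Detect and again after the OBB concatenation puts the channel and anchor axes out of order for torch.cat.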

Melody-Zhou commented 7 months ago

I just ran a quick test with the latest YOLOv8 code and the export works fine. I only modified line 141 of ultralytics/nn/modules/head.py, as follows:

    # return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
    return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

The export.py script is as follows:

    from ultralytics import YOLO

    model = YOLO("yolov8s-obb.pt")
    # dynamic=True exports dynamic axes; simplify=True runs onnxsim on the exported graph
    success = model.export(format="onnx", dynamic=True, simplify=True)

The output is:

Ultralytics YOLOv8.1.40 🚀 Python-3.8.16 torch-1.12.1 CPU (12th Gen Intel Core(TM) i5-12400F)
YOLOv8s-obb summary (fused): 187 layers, 11417376 parameters, 0 gradients, 29.4 GFLOPs

PyTorch: starting from 'yolov8s-obb.pt' with input shape (1, 3, 1024, 1024) BCHW and output shape(s) (1, 21504, 20) (22.2 MB)

ONNX: starting export with onnx 1.13.1 opset 10...
ONNX: simplifying with onnxsim 0.4.35...
ONNX: export success ✅ 6.3s, saved as 'yolov8s-obb.onnx' (43.6 MB)

Export complete (9.4s)
Results saved to C:\Users\Admin\Desktop\test\ultralytics-main
Predict:         yolo predict task=obb model=yolov8s-obb.onnx imgsz=1024  
Validate:        yolo val task=obb model=yolov8s-obb.onnx imgsz=1024 data=runs/DOTAv1.0-ms.yaml  
Visualize:       https://netron.app
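The output shape (1, 21504, 20) in this log can be checked by hand. A minimal sketch, assuming the default YOLOv8 detection strides of 8, 16 and 32, the 15 DOTAv1 classes and a single angle channel (none of these is stated in the thread): the anchor count is the total number of grid cells over the three feature maps, and the last dimension is 4 box values + nc class scores + 1 angle. num_anchors and obb_output_shape are hypothetical helper names.

```python
def num_anchors(imgsz: int, strides=(8, 16, 32)) -> int:
    """Total grid cells (anchor points) across the three feature maps for a square input."""
    return sum((imgsz // s) ** 2 for s in strides)


def obb_output_shape(imgsz: int, nc: int, ne: int = 1, bs: int = 1):
    """Assumed export output after the permute: (bs, anchors, 4 box + nc classes + ne angle)."""
    return (bs, num_anchors(imgsz), 4 + nc + ne)


print(num_anchors(640))               # 8400
print(obb_output_shape(1024, nc=15))  # (1, 21504, 20)
```

For a 1024×1024 input this gives 128² + 64² + 32² = 21504 anchors; for 640×640 it gives 8400, the anchor count that appears in the error message in the original post.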

You could try cloning the latest code and exporting again.