ultralytics / ultralytics

NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite
https://docs.ultralytics.com
GNU Affero General Public License v3.0

After adding post-processing to the model output, repackage and export the new model #16176

Open canhe173 opened 1 week ago

canhe173 commented 1 week ago

Question

For YOLOv8 segmentation models, the raw model output has to be post-processed before you can get information such as boxes, classes, scores, and masks. I therefore want to append this post-processing to the model's original output and export a new model with four outputs: boxes, classes, scores, and masks.

Below is my code. It does not export a model with the new outputs; the exported file is still the original model. Could someone help?

import torch
import torchvision
from ultralytics import YOLO


class YoloModelWithPostProcess(YOLO):
    def __init__(self, model_path):
        super().__init__(model_path)  # loads the weights into self.model
        self.model.eval()

    def forward(self, x):
        with torch.no_grad():
            yolo_output = self.model(x)
            boxes, confidences, class_ids, masks = self.postprocess(yolo_output)
        return boxes, confidences, class_ids, masks

    def postprocess(self, preds, conf_threshold=0.025, iou_threshold=0.045, nm=32):
        # preds[0]: (B, 4 + num_classes + nm, num_anchors)
        # preds[1]: in eval mode, a tuple whose element [2] holds the mask prototypes (B, nm, mh, mw)
        x, protos = preds[0], preds[1][2][0]  # prototypes of the first (only) image
        x = x.permute(0, 2, 1)  # -> (B, num_anchors, 4 + num_classes + nm)
        # Predictions filtering by conf-threshold (boolean indexing flattens the batch dim)
        x = x[torch.amax(x[..., 4:-nm], dim=-1) > conf_threshold]
        # Merge (box, score, class, mask coefficients) into one matrix
        scores = torch.amax(x[..., 4:-nm], dim=-1)
        class_indices = torch.argmax(x[..., 4:-nm], dim=-1).float()
        x = torch.cat([x[..., :4], scores.unsqueeze(-1), class_indices.unsqueeze(-1), x[..., -nm:]], dim=-1)
        # Bounding boxes format change: cxcywh -> xyxy (torchvision NMS expects xyxy boxes)
        x[..., [0, 1]] -= x[..., [2, 3]] / 2
        x[..., [2, 3]] += x[..., [0, 1]]
        # NMS filtering; index with the returned tensor directly (no .numpy())
        x = x[torchvision.ops.nms(x[:, :4], x[:, 4], iou_threshold)]

        boxes_output = x[:, :4]
        confidences_output = x[:, 4:5]
        class_ids_output = x[:, 5:6]
        # Process masks: (K, nm) @ (nm, mh * mw) -> (K, mh, mw), already in NHW order
        c, mh, mw = protos.shape
        masks_output = torch.matmul(x[:, 6:], protos.view(c, -1)).view(-1, mh, mw)
        # Always return all four outputs, even when nothing survives NMS
        return boxes_output, confidences_output, class_ids_output, masks_output


model_with_postprocess = YoloModelWithPostProcess('yolov8n-seg.pt')
model_with_postprocess.export(format="onnx")
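For reference, the same post-processing can also run outside the model, on the two raw outputs of an unmodified export. For yolov8n-seg at 640x640 these are typically a (1, 116, 8400) prediction tensor (4 box values + 80 class scores + 32 mask coefficients per anchor) and (1, 32, 160, 160) mask prototypes. The NumPy sketch below assumes a single image, swaps torchvision's NMS for a plain greedy NMS, and, unlike ultralytics' own mask processing, does not crop masks to their boxes or upsample them to the input size. decode_seg_output() and its helpers are hypothetical names, and the threshold values are only examples.

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one xyxy box and an (M, 4) array of xyxy boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter + 1e-9)


def greedy_nms(boxes, scores, iou_threshold):
    """Greedy NMS: keep the best box, drop boxes overlapping it, repeat."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        order = rest[iou_one_to_many(boxes[i], boxes[rest]) <= iou_threshold]
    return np.array(keep, dtype=np.int64)


def decode_seg_output(pred, protos, conf_threshold=0.25, iou_threshold=0.45, nm=32):
    """pred: (4 + nc + nm, N) for one image; protos: (nm, mh, mw)."""
    x = pred.T  # (N, 4 + nc + nm)
    cls_scores = x[:, 4:-nm]
    scores = cls_scores.max(axis=1)
    class_ids = cls_scores.argmax(axis=1)
    # Confidence filtering
    keep = scores > conf_threshold
    x, scores, class_ids = x[keep], scores[keep], class_ids[keep]
    # cxcywh -> xyxy
    boxes = np.empty((x.shape[0], 4), dtype=x.dtype)
    boxes[:, :2] = x[:, :2] - x[:, 2:4] / 2
    boxes[:, 2:] = x[:, :2] + x[:, 2:4] / 2
    # NMS, then gather everything for the surviving detections
    idx = greedy_nms(boxes, scores, iou_threshold)
    boxes, scores, class_ids = boxes[idx], scores[idx], class_ids[idx]
    coeffs = x[idx, -nm:]
    # Mask logits = coefficients @ prototypes, squashed to [0, 1] with a sigmoid
    c, mh, mw = protos.shape
    masks = 1 / (1 + np.exp(-(coeffs @ protos.reshape(c, -1))))
    return boxes, scores, class_ids, masks.reshape(-1, mh, mw)
```

With ONNX Runtime outputs out0 and out1 of the original export, a call such as decode_seg_output(out0[0], out1[0]) would return boxes, scores, class ids and prototype-resolution masks; thresholding the masks at 0.5 is equivalent to keeping positive logits.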


Y-T-G commented 1 week ago
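  1. You have to override the forward of YOLO.model, not of YOLO: export() traces the underlying nn.Module (YOLO.model), so a forward defined on the YOLO wrapper is never called.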
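  2. Adding NMS to the model's graph is not so straightforward, since the number of detections left after confidence filtering and NMS varies from image to image, which makes the exported output shapes data-dependent.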
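Point 1 can be sketched as a separate nn.Module whose forward wraps the network itself. This is a minimal sketch, assuming PyTorch and that the wrapped module returns (pred, extra) the way the YOLOv8 segmentation head does in eval mode: pred shaped (B, 4 + nc + nm, N), with the prototypes either extra itself or its last element. SegOutputSplit and _StandInSeg are hypothetical names; _StandInSeg is a dummy network so the sketch runs without ultralytics. In line with point 2, the wrapper stops short of confidence filtering and NMS, so every output keeps a fixed shape and NMS runs at inference time.

```python
import torch
import torch.nn as nn


class SegOutputSplit(nn.Module):
    """Hypothetical wrapper: runs a YOLOv8-seg style module and splits its raw output.

    Assumes base(images) returns (pred, extra), pred of shape (B, 4 + nc + nm, N),
    and the mask prototypes (B, nm, mh, mw) are either extra itself or extra[-1].
    No confidence filtering or NMS, so all outputs have fixed shapes.
    """

    def __init__(self, base: nn.Module, nm: int = 32):
        super().__init__()
        self.base = base
        self.nm = nm

    def forward(self, images: torch.Tensor):
        pred, extra = self.base(images)
        protos = extra[-1] if isinstance(extra, (tuple, list)) else extra
        x = pred.transpose(1, 2)  # (B, N, 4 + nc + nm)
        cx, cy, w, h = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
        # cxcywh -> xyxy
        boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
        # Best class score and its index for every anchor
        scores, class_ids = x[..., 4:-self.nm].max(dim=-1)
        mask_coeffs = x[..., -self.nm:]
        return boxes, scores, class_ids, mask_coeffs, protos


class _StandInSeg(nn.Module):
    """Dummy stand-in for a YOLOv8-seg network (2 classes, 3 mask coefficients, 5 anchors)."""

    def forward(self, images: torch.Tensor):
        b = images.shape[0]
        pred = torch.zeros(b, 4 + 2 + 3, 5)
        pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3] = 10.0, 20.0, 4.0, 6.0  # cx, cy, w, h
        pred[:, 4], pred[:, 5] = 0.2, 0.7  # class scores
        protos = torch.ones(b, 3, 8, 8)
        return pred, (None, None, protos)  # mimics the eval-mode (pred, (..., ..., protos)) layout


wrapper = SegOutputSplit(_StandInSeg(), nm=3).eval()
boxes, scores, class_ids, mask_coeffs, protos = wrapper(torch.zeros(2, 3, 64, 64))
```

With ultralytics installed, the real network would presumably be wrapped as SegOutputSplit(YOLO("yolov8n-seg.pt").model.eval()) and exported with torch.onnx.export(wrapper, torch.zeros(1, 3, 640, 640), "yolov8n-seg-split.onnx", output_names=["boxes", "scores", "class_ids", "mask_coeffs", "protos"]). This goes through torch.onnx.export directly rather than YOLO.export, so ultralytics' exporter settings do not apply.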