sophgo / tpu-mlir

Machine learning compiler based on MLIR for Sophgo TPU.

Providing an execution script for tpu-mlir/python/sample/detect_yolov8.py #177

Open wlc952 opened 4 months ago

wlc952 commented 4 months ago
#!/usr/bin/env python3
# Copyright (C) 2022 Sophgo Technologies Inc.  All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================
try:
    from tpu_mlir.python import *
except ImportError:
    pass

import numpy as np
import os
import sys
import argparse
import cv2
from tools.model_runner import mlir_inference, model_inference, onnx_inference, torch_inference

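# The 80 COCO class names, indexed by class id.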
classes = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat',
    9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat',
    16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe',
    24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis',
    31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard',
    37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
    44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot',
    52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed',
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard',
    67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
    74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}

class YOLOv8:
    def __init__(self, model_path, net_input_shape, input_image, confidence_thres, iou_thres):
        self.model_path = model_path
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres
        self.input_size = tuple(map(int, net_input_shape.split(',')))  # (h, w)
        self.classes = classes
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
        self.pixel_format = 'rgb'
        self.channel_format = 'nchw'
        self.img = cv2.imread(self.input_image)

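    # Draw one detection on the image: bounding box, filled label background,
    # and "class: score" text.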
    def draw_detections(self, img, box, score, class_id):
        x1, y1, w, h = box
        color = self.color_palette[class_id]
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
        label = f"{self.classes[class_id]}: {score:.2f}"
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        return img

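    # Letterbox preprocessing: resize with preserved aspect ratio, pad to the
    # network input size with gray (114), convert HWC/BGR to CHW/RGB, and
    # return the scale ratio and padding offsets for mapping boxes back.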
    def preproc(self):
        img = self.img
        if len(img.shape) == 3:
            padded_img = np.ones((self.input_size[0], self.input_size[1], 3), dtype=np.uint8) * 114  # pad with gray (114)
        else:
            padded_img = np.ones(self.input_size, dtype=np.uint8) * 114  # pad with gray (114)

        r = min(self.input_size[0] / img.shape[0], self.input_size[1] / img.shape[1])

        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        top = int((self.input_size[0] - int(img.shape[0] * r)) / 2)
        left = int((self.input_size[1] - int(img.shape[1] * r)) / 2)
        padded_img[top:int(img.shape[0] * r) + top, left:int(img.shape[1] * r) + left] = resized_img

        if self.channel_format == 'nchw':
            padded_img = padded_img.transpose((2, 0, 1))  # HWC to CHW
        if self.pixel_format == 'rgb':
            padded_img = padded_img[::-1]  # BGR to RGB (flips the channel axis, which is first after the CHW transpose)

        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)

        return padded_img, r, top, left

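    # Decode raw predictions, filter by confidence, apply greedy per-class NMS,
    # then rescale boxes to the original image and draw them.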
    def postproc(self, output, r, top, left):
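        # Intersection area of two (cx, cy, w, h) boxes; 0 if they do not overlap.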
        def getInter(box1, box2):
            box1_x1, box1_y1, box1_x2, box1_y2 = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2, \
                                                box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
            box2_x1, box2_y1, box2_x2, box2_y2 = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2, \
                                                box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
            if box1_x1 > box2_x2 or box1_x2 < box2_x1:
                return 0
            if box1_y1 > box2_y2 or box1_y2 < box2_y1:
                return 0
            x_list = [box1_x1, box1_x2, box2_x1, box2_x2]
            x_list = np.sort(x_list)
            x_inter = x_list[2] - x_list[1]
            y_list = [box1_y1, box1_y2, box2_y1, box2_y2]
            y_list = np.sort(y_list)
            y_inter = y_list[2] - y_list[1]
            inter = x_inter * y_inter
            return inter
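        # IoU of two (cx, cy, w, h) boxes: intersection over union.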
        def getIou(box1, box2):
            inter_area = getInter(box1, box2)
            box1_area = box1[2] * box1[3]
            box2_area = box2[2] * box2[3]
            union = box1_area + box2_area - inter_area
            iou = inter_area / union
            return iou

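        # Raw output is (1, 4 + num_classes, N): transpose to (N, 4 + num_classes)
        # and insert the max class score at column 4, so each row becomes
        # (x, y, w, h, conf, cls_scores...).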
        img = self.img
        pred = np.transpose(np.squeeze(output[0]))
        pred_class = pred[..., 4:]
        pred_conf = np.max(pred_class, axis=-1)
        pred = np.insert(pred, 4, pred_conf, axis=-1)

        conf = pred[..., 4] > self.confidence_thres
        true_pred = pred[conf]
        true_cls_score = true_pred[..., 5:]

        all_cls = [int(np.argmax(cls_scores)) for cls_scores in true_cls_score]
        classes = list(set(all_cls))

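        # Greedy per-class NMS: keep the highest-scoring box, discard remaining
        # boxes of the same class whose IoU with it exceeds the threshold.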
        output_boxes = []
        for cls in classes:
            clss_mask = np.array(all_cls) == cls
            clss_boxes = true_pred[clss_mask][:, :6]
            clss_boxes[:, 5] = cls
            clss_boxes = clss_boxes[np.argsort(clss_boxes[:, 4])[::-1]]

            while len(clss_boxes) > 0:
                current_box = clss_boxes[0]
                output_boxes.append(current_box)
                if len(clss_boxes) == 1:
                    break
                ious = np.array([getIou(current_box, box) for box in clss_boxes[1:]])
                clss_boxes = clss_boxes[1:][ious < self.iou_thres]

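        # Map boxes from the letterboxed input back to original image
        # coordinates: undo the padding offsets, then the resize ratio.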
        for bb in output_boxes:
            x, y, w, h = bb[:4]
            box_x = int((x - w / 2 - left) / r)
            box_y = int((y - h / 2 - top) / r)
            box_width = int(w / r)
            box_height = int(h / r)
            box = [box_x, box_y, box_width, box_height]
            score = bb[4]
            class_id = int(bb[5])
            img = self.draw_detections(img, box, score, class_id)

        return img

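    # Run preprocessing, dispatch to the inference backend matching the model
    # file extension, and post-process the first output tensor.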
    def main(self):
        img, ratio, top, left = self.preproc()
        img = np.expand_dims(img, axis=0)
        img /= 255.0  # scale pixels to [0, 1]

        data = {"data": img}  # key must match the model's input tensor name
        output = dict()
        if self.model_path.endswith('.onnx'):
            output = onnx_inference(data, self.model_path, False)
        elif self.model_path.endswith('.pt') or self.model_path.endswith('.pth'):
            output = torch_inference(data, self.model_path, False)
        elif self.model_path.endswith('.mlir'):
            output = mlir_inference(data, self.model_path, False)
        elif self.model_path.endswith(".bmodel"):
            output = model_inference(data, self.model_path)
        elif self.model_path.endswith(".cvimodel"):
            output = model_inference(data, self.model_path, False)
        else:
            raise RuntimeError("unsupported model file: {}".format(self.model_path))
        outputs = next(iter(output.values()))
        return self.postproc(outputs, ratio, top, left) 

def parse_args():
    parser = argparse.ArgumentParser(description='Inference YOLOv8 network.')
    parser.add_argument("--model", type=str, required=True, help="Model definition file")
    parser.add_argument("--net_input_dims", type=str, default="640,640", help="(h,w) of net input")
    parser.add_argument("--input", type=str, required=True, help="Input image for testing")
    parser.add_argument("--output", type=str, required=True, help="Output image after detection")
    parser.add_argument("--conf_thres", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--iou_thres", type=float, default=0.6, help="NMS IOU threshold")

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    detection = YOLOv8(args.model, args.net_input_dims, args.input, args.conf_thres, args.iou_thres)
    output_image = detection.main()
    cv2.imwrite(args.output, output_image)
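
For reference, a minimal sketch of how the script might be invoked; the model and image file names below are hypothetical placeholders:

# ONNX model
python3 detect_yolov8.py --model yolov8n.onnx --input dog.jpg --output dog_onnx.jpg

# compiled bmodel (net_input_dims must match the shape the model was compiled for)
python3 detect_yolov8.py --model yolov8n_f16.bmodel --net_input_dims 640,640 \
    --input dog.jpg --output dog_bmodel.jpg --conf_thres 0.5 --iou_thres 0.6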