wangzhaode / mnn-yolo

MNN YOLO demos.
55 stars 7 forks source link

How to Detect Using YOLOv8-OBB Model #15

Open sungerk opened 1 month ago

sungerk commented 1 month ago
# -*- coding: utf-8 -*-
import math

import MNN
import MNN.numpy as np
import MNN.cv as cv2

def inference(model, img, precision, backend, thread):
    config = {}
    config['precision'] = precision
    config['backend'] = backend
    config['numThread'] = thread
    rt = MNN.nn.create_runtime_manager((config,))
    net = MNN.nn.load_module_from_file(model, [], [], runtime_manager=rt)
    original_image = cv2.imread(img)
    ih, iw, _ = original_image.shape
    length = max((ih, iw))
    scale = length / 640
    image = np.pad(original_image, [[0, length - ih], [0, length - iw], [0, 0]], 'constant')
    image = cv2.resize(image, (640, 640), 0., 0., cv2.INTER_LINEAR, -1, [0., 0., 0.], [1./255., 1./255., 1./255.])
    input_var = np.expand_dims(image, 0)
    input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4)
    output_var = net.forward(input_var)
    output_var = MNN.expr.convert(output_var, MNN.expr.NCHW)
    output_var = output_var.squeeze()

    # output_var shape: [85, 8400]; 85 means: [cx, cy, w, h, theta, prob * 80]
    cx = output_var[0]
    cy = output_var[1]
    w = output_var[2]
    h = output_var[3]
    theta = output_var[4]
    probs = output_var[5:]

    # Calculate oriented bounding boxes
    boxes = np.stack([cx, cy, w, h, theta], axis=1)
    scores = np.max(probs, 0)
    class_ids = np.argmax(probs, 0)
    result_ids = MNN.expr.nms(boxes, scores, 10, 0.5, 0.5)
    print(result_ids.shape)

    result_boxes = boxes[result_ids]
    result_scores = scores[result_ids]
    result_class_ids = class_ids[result_ids]
    for i in range(len(result_boxes)):
        cx, cy, w, h, theta = result_boxes[i].read_as_tuple()
        cx = int(cx * scale)
        cy = int(cy * scale)
        w = int(w * scale)
        h = int(h * scale)
        theta = float(theta)
        print(result_class_ids[i])

        # Draw rotated rectangle
        rect = ((cx, cy), (w, h), theta)
        box = cv2.boxPoints(rect)
        box = np.int0(box)
        cv2.drawContours(original_image, [box], 0, (0, 0, 255), 2)

    cv2.imwrite('res.jpg', original_image)

if __name__ == "__main__":
    # Script entry point: run the sample OBB model on a sample image on the CPU.
    inference(
        model='obb.mnn',
        img='a.png',
        precision='normal',
        backend='CPU',
        thread=4,
    )

However, the script does not detect the objects accurately.

sungerk commented 1 week ago

搞定了 (Solved.)

https://github.com/alibaba/MNN/issues/3054