Open sungerk opened 2 months ago
#-- coding:utf8 --
import MNN
import MNN.numpy as np
import MNN.cv as cv2


def inference(model, img, precision, backend, thread):
    """Run an oriented-bounding-box model on one image and save 'res.jpg'.

    The image is padded to a square, resized to 640x640, passed through the
    model, filtered with NMS, and every kept detection is drawn as a rotated
    rectangle on the original image.

    Args:
        model: Path of the MNN model file to load.
        img: Path of the input image.
        precision: Value stored under the 'precision' runtime config key.
        backend: Value stored under the 'backend' runtime config key.
        thread: Value stored under the 'numThread' runtime config key.
    """
    config = {
        'precision': precision,
        'backend': backend,
        'numThread': thread,
    }
    rt = MNN.nn.create_runtime_manager((config,))
    net = MNN.nn.load_module_from_file(model, [], [], runtime_manager=rt)

    original_image = cv2.imread(img)
    ih, iw, _ = original_image.shape
    length = max((ih, iw))
    # Factor that maps 640x640 network coordinates back to the padded image.
    scale = length / 640

    # Pad bottom/right so the image is square before resizing.
    image = np.pad(original_image,
                   [[0, length - ih], [0, length - iw], [0, 0]],
                   'constant')
    image = cv2.resize(image, (640, 640), 0., 0., cv2.INTER_LINEAR, -1,
                       [0., 0., 0.], [1./255., 1./255., 1./255.])
    input_var = np.expand_dims(image, 0)
    input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4)
    output_var = net.forward(input_var)
    output_var = MNN.expr.convert(output_var, MNN.expr.NCHW)
    output_var = output_var.squeeze()

    # output_var shape: [85, 8400]; 85 means: [cx, cy, w, h, theta, prob * 80]
    # NOTE(review): this layout is taken from the original comment - confirm
    # it against the exported model. If the angle channel sits elsewhere
    # (e.g. after the class scores), the slicing below must change.
    cx = output_var[0]
    cy = output_var[1]
    w = output_var[2]
    h = output_var[3]
    theta = output_var[4]
    probs = output_var[5:]

    # Oriented box parameters, one row per candidate; used only for drawing.
    obb_params = np.stack([cx, cy, w, h, theta], axis=1)

    # MNN.expr.nms expects boxes of shape [num, 4], so feed it axis-aligned
    # corner boxes instead of the 5-column oriented parameters. The rotation
    # is ignored here, which makes the overlap test an approximation.
    half_w = w * 0.5
    half_h = h * 0.5
    nms_boxes = np.stack(
        [cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=1)

    scores = np.max(probs, 0)
    class_ids = np.argmax(probs, 0)
    result_ids = MNN.expr.nms(nms_boxes, scores, 10, 0.5, 0.5)
    print(result_ids.shape)

    result_boxes = obb_params[result_ids]
    result_scores = scores[result_ids]
    result_class_ids = class_ids[result_ids]

    for i in range(len(result_boxes)):
        box_cx, box_cy, box_w, box_h, box_theta = \
            result_boxes[i].read_as_tuple()
        box_cx = int(box_cx * scale)
        box_cy = int(box_cy * scale)
        box_w = int(box_w * scale)
        box_h = int(box_h * scale)
        # NOTE(review): the angle is forwarded unchanged; confirm that the
        # unit the model emits matches the unit boxPoints expects.
        box_theta = float(box_theta)
        print(result_class_ids[i])

        # Draw rotated rectangle
        rect = ((box_cx, box_cy), (box_w, box_h), box_theta)
        box = cv2.boxPoints(rect)
        box = np.int0(box)
        cv2.drawContours(original_image, [box], 0, (0, 0, 255), 2)

    cv2.imwrite('res.jpg', original_image)


if __name__ == "__main__":
    model = 'obb.mnn'
    img = 'a.png'
    precision = 'normal'
    backend = 'CPU'
    thread = 4
    inference(model, img, precision, backend, thread)
However, the script above does not extract the detection results accurately.
搞定了
https://github.com/alibaba/MNN/issues/3054
However, it cannot accurately extract the content